This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 6523457  [SYSTEMDS-3018] Costing of Federated Execution Plans
6523457 is described below

commit 65234573b986e8d0ba5bd27bd3d7641b921a7641
Author: sebwrede <[email protected]>
AuthorDate: Fri Jun 18 17:31:31 2021 +0200

    [SYSTEMDS-3018] Costing of Federated Execution Plans
    
    Closes #1367.
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  93 +++++--
 .../codegen/opt/PlanSelectionFuseCostBasedV2.java  | 182 +-------------
 .../org/apache/sysds/hops/cost/ComputeCost.java    | 225 +++++++++++++++++
 .../sysds/hops/cost/CostEstimationWrapper.java     |  13 +-
 .../org/apache/sysds/hops/cost/CostEstimator.java  |  72 +++---
 .../hops/cost/CostEstimatorStaticRuntime.java      |  40 +--
 .../org/apache/sysds/hops/cost/FederatedCost.java  | 117 +++++++++
 .../sysds/hops/cost/FederatedCostEstimator.java    | 214 ++++++++++++++++
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   1 +
 .../hops/rewrite/RewriteFederatedExecution.java    | 133 ++++++++--
 .../rewrite/RewriteFederatedStatementBlocks.java   |  66 +++++
 .../runtime/instructions/FEDInstructionParser.java |   1 +
 .../fed/AggregateUnaryFEDInstruction.java          |  56 ++++-
 .../instructions/fed/AppendFEDInstruction.java     |   5 +-
 .../instructions/fed/CtableFEDInstruction.java     |   4 +-
 .../runtime/instructions/fed/FEDInstruction.java   |   3 +
 .../instructions/fed/ReorgFEDInstruction.java      |  36 ++-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  26 ++
 .../fedplanning/FederatedCostEstimatorTest.java    | 279 +++++++++++++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest.java |  33 +--
 .../fedplanning/BinaryCostEstimatorTest.dml        |  26 ++
 .../FederatedMultiplyCostEstimatorTest.dml         |  31 +++
 .../fedplanning/ForLoopCostEstimatorTest.dml       |  27 ++
 .../fedplanning/FunctionCostEstimatorTest.dml      |  28 +++
 .../fedplanning/IfElseCostEstimatorTest.dml        |  30 +++
 .../fedplanning/ParForLoopCostEstimatorTest.dml    |  27 ++
 .../privacy/fedplanning/WhileCostEstimatorTest.dml |  27 ++
 27 files changed, 1483 insertions(+), 312 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index ececf52..45fc3af 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.cost.FederatedCost;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.hops.recompile.Recompiler.ResetType;
 import org.apache.sysds.lops.CSVReBlock;
@@ -91,8 +92,9 @@ public abstract class Hop implements ParseInfo {
         * If it is lout, the output should be retrieved by the coordinator.
         */
        protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+       protected FederatedCost _federatedCost = new FederatedCost();
        
-       // Estimated size for the output produced from this Hop
+       // Estimated size for the output produced from this Hop in bytes
        protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
        
        // Estimated size for the entire operation represented by this Hop
@@ -535,7 +537,7 @@ public abstract class Hop implements ParseInfo {
         * only use getMemEstimate(), which gives memory required to store 
         * all inputs and the output.
         * 
-        * @return output size memory estimate
+        * @return output size memory estimate in bytes
         */
        protected double getOutputSize() {
                return _outputMemEstimate;
@@ -545,14 +547,22 @@ public abstract class Hop implements ParseInfo {
                return getInputSize(null);
        }
 
-       protected double getInputSize(Collection<String> exclVars) {
+       /**
+        * Get the memory estimate of inputs as the sum of input estimates in bytes.
+        * @param exclVars name of input hops to exclude from the input estimate
+        * @param injectedDefault default memory estimate (bytes) used when the memory estimate of the input is negative
+        * @return input memory estimate in bytes
+        */
+       protected double getInputSize(Collection<String> exclVars, double injectedDefault){
                double sum = 0;
                int len = _input.size();
                for( int i=0; i<len; i++ ) { //for all inputs
                        Hop hi = _input.get(i);
                        if( exclVars != null && exclVars.contains(hi.getName()) )
                                continue;
-                       double hmout = hi.getOutputMemEstimate();
+                       double hmout = hi.getOutputMemEstimate(injectedDefault);
+                       if (hmout < 0)
+                               hmout = injectedDefault*(Math.max(hi.getDim1(),1) * Math.max(hi.getDim2(),1));
                        if( hmout > 1024*1024 ) {//for relevant sizes
                                //check if already included in estimate (if an input is used
                                //multiple times it is still only required once in memory)
@@ -564,10 +574,19 @@ public abstract class Hop implements ParseInfo {
                        }
                        sum += hmout;
                }
-               
+
                return sum;
        }
 
+       /**
+        * Get the memory estimate of inputs as the sum of input estimates in bytes.
+        * @param exclVars name of input hops to exclude from the input estimate
+        * @return input memory estimate in bytes
+        */
+       protected double getInputSize(Collection<String> exclVars) {
+               return getInputSize(exclVars, OptimizerUtils.INVALID_SIZE);
+       }
+
        protected double getInputSize( int pos ){
                double ret = 0;
                if( _input.size()>pos )
@@ -582,12 +601,11 @@ public abstract class Hop implements ParseInfo {
        /**
         * NOTES:
         * * Purpose: Whenever the output dimensions / sparsity of a hop are unknown, this hop
-        *   should store its worst-case output statistics (if known) in that table. Subsequent
-        *   hops can then
+        *   should store its worst-case output statistics (if known) in that table.
         * * Invocation: Intended to be called for ALL root nodes of one Hops DAG with the same
         *   (initially empty) memo table.
         * 
-        * @return memory estimate
+        * @return memory estimate in bytes
         */
        public double getMemEstimate() {
                if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
@@ -620,17 +638,46 @@ public abstract class Hop implements ParseInfo {
        }
 
        //wrappers for meaningful public names to memory estimates.
-       
+
+       /**
+        * Get the memory estimate of inputs as the sum of input estimates in bytes.
+        * @return input memory estimate in bytes
+        */
        public double getInputMemEstimate()
        {
                return getInputSize();
        }
-       
+
+       /**
+        * Get the memory estimate of inputs as the sum of input estimates in bytes.
+        * @param injectedDefault default memory estimate (bytes) used when the memory estimate of the input is negative
+        * @return input memory estimate in bytes
+        */
+       public double getInputMemEstimate(double injectedDefault){
+               return getInputSize(null, injectedDefault);
+       }
+
+       /**
+        * Output memory estimate in bytes.
+        * @return output memory estimate in bytes
+        */
        public double getOutputMemEstimate()
        {
                return getOutputSize();
        }
 
+       /**
+        * Output memory estimate in bytes with negative memory estimates replaced by the injected default.
+        * The injected default represents the memory estimate per output cell, hence it is multiplied by the estimated
+        * dimensions of the output of the hop.
+        * @param injectedDefault memory estimate to be returned in case the memory estimate defaults to a negative number
+        * @return output memory estimate in bytes
+        */
+       public double getOutputMemEstimate(double injectedDefault)
+       {
+               return Math.max(getOutputMemEstimate(),injectedDefault*(Math.max(getDim1(),1) * Math.max(getDim2(),1)));
+       }
+
        public double getIntermediateMemEstimate()
        {
                return getIntermediateSize();
@@ -823,17 +870,13 @@ public abstract class Hop implements ParseInfo {
         * This method only has an effect if FEDERATED_COMPILATION is activated.
         */
        protected void updateETFed(){
-               if ( _federatedOutput == FederatedOutput.FOUT || _federatedOutput == FederatedOutput.LOUT )
+               if ( _federatedOutput.isForced() )
                        _etype = ExecType.FED;
        }
        
        public boolean isFederated(){
                return getExecType() == ExecType.FED;
        }
-       
-       public boolean isFederatedOutput(){
-               return _federatedOutput == FederatedOutput.FOUT;
-       }
 
        public boolean someInputFederated(){
                return getInput().stream().anyMatch(Hop::hasFederatedOutput);
@@ -889,6 +932,26 @@ public abstract class Hop implements ParseInfo {
                return _federatedOutput == FederatedOutput.FOUT;
        }
 
+       public boolean hasLocalOutput(){
+               return _federatedOutput == FederatedOutput.LOUT;
+       }
+
+       /**
+        * Check if federated cost has been initialized for this Hop.
+        * @return true if federated cost has been initialized
+        */
+       public boolean federatedCostInitialized(){
+               return _federatedCost.getTotal() > 0;
+       }
+
+       public FederatedCost getFederatedCost(){
+               return _federatedCost;
+       }
+
+       public void setFederatedCost(FederatedCost cost){
+               _federatedCost = cost;
+       }
+
        public void setUpdateType(UpdateType update){
                _updateType = update;
        }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 0b20876..c6cfe9f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -46,17 +46,8 @@ import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DnnOp;
 import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.IndexingOp;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.NaryOp;
 import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.ParameterizedBuiltinOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.UnaryOp;
 import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem;
 import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
@@ -64,6 +55,7 @@ import org.apache.sysds.hops.codegen.template.TemplateRow;
 import org.apache.sysds.hops.codegen.template.TemplateUtils;
 import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
 import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysds.hops.cost.ComputeCost;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
 import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
@@ -175,8 +167,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
                        //obtain hop compute costs per cell once
                        HashMap<Long, Double> computeCosts = new HashMap<>();
                        for( Long hopID : part.getPartition() )
-                               getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
-                       
+                               computeCosts.put(hopID, ComputeCost.getHOPComputeCost(memo.getHopRefs().get(hopID)));
+
                        //prepare pruning helpers and prune memo table w/ determined mat points
                        StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts),
                                getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
@@ -1011,174 +1003,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
                return costs;
        }
        
-       private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) 
-       {
-               //get costs for given hop
-               double costs = 1;
-               if( current instanceof UnaryOp ) {
-                       switch( ((UnaryOp)current).getOp() ) {
-                               case ABS:
-                               case ROUND:
-                               case CEIL:
-                               case FLOOR:
-                               case SIGN:    costs = 1; break; 
-                               case SPROP:
-                               case SQRT:    costs = 2; break;
-                               case EXP:     costs = 18; break;
-                               case SIGMOID: costs = 21; break;
-                               case LOG:
-                               case LOG_NZ:  costs = 32; break;
-                               case NCOL:
-                               case NROW:
-                               case PRINT:
-                               case ASSERT:
-                               case CAST_AS_BOOLEAN:
-                               case CAST_AS_DOUBLE:
-                               case CAST_AS_INT:
-                               case CAST_AS_MATRIX:
-                               case CAST_AS_SCALAR: costs = 1; break;
-                               case SIN:     costs = 18; break;
-                               case COS:     costs = 22; break;
-                               case TAN:     costs = 42; break;
-                               case ASIN:    costs = 93; break;
-                               case ACOS:    costs = 103; break;
-                               case ATAN:    costs = 40; break;
-                               case SINH:    costs = 93; break; // TODO:
-                               case COSH:    costs = 103; break;
-                               case TANH:    costs = 40; break;
-                               case CUMSUM:
-                               case CUMMIN:
-                               case CUMMAX:
-                               case CUMPROD: costs = 1; break;
-                               case CUMSUMPROD: costs = 2; break;
-                               default:
-                                       LOG.warn("Cost model not "
-                                               + "implemented yet for: "+((UnaryOp)current).getOp());
-                       }
-               }
-               else if( current instanceof BinaryOp ) {
-                       switch( ((BinaryOp)current).getOp() ) {
-                               case MULT: 
-                               case PLUS:
-                               case MINUS:
-                               case MIN:
-                               case MAX: 
-                               case AND:
-                               case OR:
-                               case EQUAL:
-                               case NOTEQUAL:
-                               case LESS:
-                               case LESSEQUAL:
-                               case GREATER:
-                               case GREATEREQUAL: 
-                               case CBIND:
-                               case RBIND:   costs = 1; break;
-                               case INTDIV:  costs = 6; break;
-                               case MODULUS: costs = 8; break;
-                               case DIV:     costs = 22; break;
-                               case LOG:
-                               case LOG_NZ:  costs = 32; break;
-                               case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
-                                               current.getInput().get(1), 2) ? 1 : 16); break;
-                               case MINUS_NZ:
-                               case MINUS1_MULT: costs = 2; break;
-                               case MOMENT:
-                                       int type = (int) (current.getInput().get(1) instanceof LiteralOp ? 
-                                               HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
-                                       switch( type ) {
-                                               case 0: costs = 1; break; //count
-                                               case 1: costs = 8; break; //mean
-                                               case 2: costs = 16; break; //cm2
-                                               case 3: costs = 31; break; //cm3
-                                               case 4: costs = 51; break; //cm4
-                                               case 5: costs = 16; break; //variance
-                                       }
-                                       break;
-                               case COV: costs = 23; break;
-                               default:
-                                       LOG.warn("Cost model not "
-                                               + "implemented yet for: "+((BinaryOp)current).getOp());
-                       }
-               }
-               else if( current instanceof TernaryOp ) {
-                       switch( ((TernaryOp)current).getOp() ) {
-                               case IFELSE:
-                               case PLUS_MULT: 
-                               case MINUS_MULT: costs = 2; break;
-                               case CTABLE:     costs = 3; break;
-                               case MOMENT:
-                                       int type = (int) (current.getInput().get(1) instanceof LiteralOp ? 
-                                               HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
-                                       switch( type ) {
-                                               case 0: costs = 2; break; //count
-                                               case 1: costs = 9; break; //mean
-                                               case 2: costs = 17; break; //cm2
-                                               case 3: costs = 32; break; //cm3
-                                               case 4: costs = 52; break; //cm4
-                                               case 5: costs = 17; break; //variance
-                                       }
-                                       break;
-                               case COV: costs = 23; break;
-                               default:
-                                       LOG.warn("Cost model not "
-                                               + "implemented yet for: "+((TernaryOp)current).getOp());
-                       }
-               }
-               else if( current instanceof NaryOp ) {
-                       costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ?
-                               current.getInput().size() : 1;
-               }
-               else if( current instanceof ParameterizedBuiltinOp ) {
-                       costs = 1;
-               }
-               else if( current instanceof IndexingOp ) {
-                       costs = 1;
-               }
-               else if( current instanceof ReorgOp ) {
-                       costs = 1;
-               }
-               else if( current instanceof DnnOp ) {
-                       switch( ((DnnOp)current).getOp() ) {
-                               case BIASADD:
-                               case BIASMULT:
-                                       costs = 2;
-                               default:
-                                       LOG.warn("Cost model not "
-                                               + "implemented yet for: "+((DnnOp)current).getOp());
-                       }
-               }
-               else if( current instanceof AggBinaryOp ) {
-                       //outer product template w/ matrix-matrix 
-                       //or row template w/ matrix-vector or matrix-matrix
-                       costs = 2 * current.getInput().get(0).getDim2();
-                       if( current.getInput().get(0).dimsKnown(true) )
-                               costs *= current.getInput().get(0).getSparsity();
-               }
-               else if( current instanceof AggUnaryOp) {
-                       switch(((AggUnaryOp)current).getOp()) {
-                               case SUM:    costs = 4; break; 
-                               case SUM_SQ: costs = 5; break;
-                               case MIN:
-                               case MAX:    costs = 1; break;
-                               default:
-                                       LOG.warn("Cost model not "
-                                               + "implemented yet for: "+((AggUnaryOp)current).getOp());
-                       }
-                       switch(((AggUnaryOp)current).getDirection()) {
-                               case Col: costs *= Math.max(current.getInput().get(0).getDim1(),1); break;
-                               case Row: costs *= Math.max(current.getInput().get(0).getDim2(),1); break;
-                               case RowCol: costs *= getSize(current.getInput().get(0)); break;
-                       }
-               }
-               
-               //scale by current output size in order to correctly reflect
-               //a mix of row and cell operations in the same fused operator
-               //(e.g., row template with fused column vector operations)
-               costs *= getSize(current);
-               
-               computeCosts.put(current.getHopID(), costs);
-       }
-       
        private static boolean hasNoRefToMatPoint(long hopID, 
                        MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
                return !InterestingPoint.isMatPoint(M, hopID, me, plan);
diff --git a/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
new file mode 100644
index 0000000..3ac64b6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DnnOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+
+/**
+ * Class with methods estimating compute costs of operations.
+ */
+public class ComputeCost {
+       private static final Log LOG = LogFactory.getLog(ComputeCost.class.getName());
+
+       /**
+        * Get compute cost for given HOP based on the number of floating point operations per output cell
+        * and the total number of output cells.
+        * @param currentHop for which compute cost is returned
+        * @return compute cost of currentHop as number of floating point operations
+        */
+       public static double getHOPComputeCost(Hop currentHop){
+               double costs = 1;
+               if( currentHop instanceof UnaryOp) {
+                       switch( ((UnaryOp)currentHop).getOp() ) {
+                               case ABS:
+                               case ROUND:
+                               case CEIL:
+                               case FLOOR:
+                               case SIGN:    costs = 1; break;
+                               case SPROP:
+                               case SQRT:    costs = 2; break;
+                               case EXP:     costs = 18; break;
+                               case SIGMOID: costs = 21; break;
+                               case LOG:
+                               case LOG_NZ:  costs = 32; break;
+                               case NCOL:
+                               case NROW:
+                               case PRINT:
+                               case ASSERT:
+                               case CAST_AS_BOOLEAN:
+                               case CAST_AS_DOUBLE:
+                               case CAST_AS_INT:
+                               case CAST_AS_MATRIX:
+                               case CAST_AS_SCALAR: costs = 1; break;
+                               case SIN:     costs = 18; break;
+                               case COS:     costs = 22; break;
+                               case TAN:     costs = 42; break;
+                               case ASIN:    costs = 93; break;
+                               case ACOS:    costs = 103; break;
+                               case ATAN:    costs = 40; break;
+                               case SINH:    costs = 93; break; // TODO:
+                               case COSH:    costs = 103; break;
+                               case TANH:    costs = 40; break;
+                               case CUMSUM:
+                               case CUMMIN:
+                               case CUMMAX:
+                               case CUMPROD: costs = 1; break;
+                               case CUMSUMPROD: costs = 2; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: "+((UnaryOp)currentHop).getOp());
+                       }
+               }
+               else if( currentHop instanceof BinaryOp) {
+                       switch( ((BinaryOp)currentHop).getOp() ) {
+                               case MULT:
+                               case PLUS:
+                               case MINUS:
+                               case MIN:
+                               case MAX:
+                               case AND:
+                               case OR:
+                               case EQUAL:
+                               case NOTEQUAL:
+                               case LESS:
+                               case LESSEQUAL:
+                               case GREATER:
+                               case GREATEREQUAL:
+                               case CBIND:
+                               case RBIND:   costs = 1; break;
+                               case INTDIV:  costs = 6; break;
+                               case MODULUS: costs = 8; break;
+                               case DIV:     costs = 22; break;
+                               case LOG:
+                               case LOG_NZ:  costs = 32; break;
+                               case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
+                                       currentHop.getInput().get(1), 2) ? 1 : 16); break;
+                               case MINUS_NZ:
+                               case MINUS1_MULT: costs = 2; break;
+                               case MOMENT:
+                                       int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+                                               HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 1; break; //count
+                                               case 1: costs = 8; break; //mean
+                                               case 2: costs = 16; break; //cm2
+                                               case 3: costs = 31; break; //cm3
+                                               case 4: costs = 51; break; //cm4
+                                               case 5: costs = 16; break; //variance
+                                       }
+                                       break;
+                               case COV: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: "+((BinaryOp)currentHop).getOp());
+                       }
+               }
+               else if( currentHop instanceof TernaryOp) {
+                       switch( ((TernaryOp)currentHop).getOp() ) {
+                               case IFELSE:
+                               case PLUS_MULT:
+                               case MINUS_MULT: costs = 2; break;
+                               case CTABLE:     costs = 3; break;
+                               case MOMENT:
+                                       int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+                                               HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+                                       switch( type ) {
+                                               case 0: costs = 2; break; //count
+                                               case 1: costs = 9; break; //mean
+                                               case 2: costs = 17; break; //cm2
+                                               case 3: costs = 32; break; //cm3
+                                               case 4: costs = 52; break; //cm4
+                                               case 5: costs = 17; break; //variance
+                                       }
+                                       break;
+                               case COV: costs = 23; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((TernaryOp)currentHop).getOp());
+                       }
+               }
+               else if( currentHop instanceof NaryOp) {
+                       costs = HopRewriteUtils.isNary(currentHop, 
Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) ?
+                               currentHop.getInput().size() : 1;
+               }
+               else if( currentHop instanceof ParameterizedBuiltinOp) {
+                       costs = 1;
+               }
+               else if( currentHop instanceof IndexingOp) {
+                       costs = 1;
+               }
+               else if( currentHop instanceof ReorgOp) {
+                       costs = 1;
+               }
+               else if( currentHop instanceof DnnOp) {
+                       switch( ((DnnOp)currentHop).getOp() ) {
+                               case BIASADD:
+                               case BIASMULT:
+                                       costs = 2;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((DnnOp)currentHop).getOp());
+                       }
+               }
+               else if( currentHop instanceof AggBinaryOp) {
+                       //outer product template w/ matrix-matrix
+                       //or row template w/ matrix-vector or matrix-matrix
+                       costs = 2 * currentHop.getInput().get(0).getDim2();
+                       if( currentHop.getInput().get(0).dimsKnown(true) )
+                               costs *= 
currentHop.getInput().get(0).getSparsity();
+               }
+               else if( currentHop instanceof AggUnaryOp) {
+                       switch(((AggUnaryOp)currentHop).getOp()) {
+                               case SUM:    costs = 4; break;
+                               case SUM_SQ: costs = 5; break;
+                               case MIN:
+                               case MAX:    costs = 1; break;
+                               default:
+                                       LOG.warn("Cost model not "
+                                               + "implemented yet for: 
"+((AggUnaryOp)currentHop).getOp());
+                       }
+                       switch(((AggUnaryOp)currentHop).getDirection()) {
+                               case Col: costs *= 
Math.max(currentHop.getInput().get(0).getDim1(),1); break;
+                               case Row: costs *= 
Math.max(currentHop.getInput().get(0).getDim2(),1); break;
+                               case RowCol: costs *= 
getSize(currentHop.getInput().get(0)); break;
+                       }
+               }
+
+               //scale by current output size in order to correctly reflect
+               //a mix of row and cell operations in the same fused operator
+               //(e.g., row template with fused column vector operations)
+               costs *= getSize(currentHop);
+               return costs;
+       }
+
+       /**
+        * Get number of output cells of given hop.
+        * @param hop for which the number of output cells are found
+        * @return number of output cells of given hop
+        */
+       private static long getSize(Hop hop) {
+               return Math.max(hop.getDim1(),1)
+                       * Math.max(hop.getDim2(),1);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java 
b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
index f8d6a2d..23fcf51 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
@@ -32,7 +32,6 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 
 public class CostEstimationWrapper 
 {
-       
        public enum CostType { 
                NUM_MRJOBS, //based on number of MR jobs, [number MR jobs]
                STATIC // based on FLOPS, read/write, etc, [time in sec]
@@ -44,17 +43,13 @@ public class CostEstimationWrapper
        private static CostEstimator _costEstim = null;
        
        
-       static 
-       {
-
+       static  {
                //create cost estimator
-               try
-               {
+               try {
                        //TODO config parameter?
                        _costEstim = createCostEstimator(DEFAULT_COSTTYPE);
                }
-               catch(Exception ex)
-               {
+               catch(Exception ex) {
                        LOG.error("Failed cost estimator initialization.", ex);
                }
        }
@@ -89,5 +84,5 @@ public class CostEstimationWrapper
                        default:
                                throw new DMLRuntimeException("Unknown cost 
type: "+type);
                }
-       }       
+       }
 }
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index bb8753c..03948d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -58,9 +58,8 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 
-public abstract class CostEstimator 
+public abstract class CostEstimator
 {
-       
        protected static final Log LOG = 
LogFactory.getLog(CostEstimator.class.getName());
        
        private static final int DEFAULT_NUMITER = 15;
@@ -84,7 +83,7 @@ public abstract class CostEstimator
        public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, 
HashMap<String,VarStats> stats, boolean recursive) {
                //obtain stats from symboltable (e.g., during recompile)
                maintainVariableStatistics(vars, stats);
-                               
+               
                //get cost estimate
                return rGetTimeEstimate(pb, stats, new HashSet<String>(), 
recursive);
        }
@@ -281,10 +280,8 @@ public abstract class CostEstimator
                VarStats[] vs = new VarStats[3];
                String[] attr = null; 
 
-               if( inst instanceof UnaryCPInstruction )
-               {
-                       if( inst instanceof DataGenCPInstruction )
-                       {
+               if( inst instanceof UnaryCPInstruction ) {
+                       if( inst instanceof DataGenCPInstruction ) {
                                DataGenCPInstruction rinst = 
(DataGenCPInstruction) inst;
                                vs[0] = _unknownStats;
                                vs[1] = _unknownStats;
@@ -298,15 +295,13 @@ public abstract class CostEstimator
                                        type = 1;
                                attr = new String[]{String.valueOf(type)};
                        }
-                       else if( inst instanceof StringInitCPInstruction )
-                       {
+                       else if( inst instanceof StringInitCPInstruction ) {
                                StringInitCPInstruction rinst = 
(StringInitCPInstruction) inst;
                                vs[0] = _unknownStats;
                                vs[1] = _unknownStats;
                                vs[2] = stats.get( rinst.output.getName() );
                        }
-                       else //general unary
-                       {
+                       else { //general unary
                                UnaryCPInstruction uinst = (UnaryCPInstruction) 
inst;
                                vs[0] = stats.get( uinst.input1.getName() );
                                vs[1] = _unknownStats;
@@ -317,69 +312,61 @@ public abstract class CostEstimator
                                if( vs[2] == null ) //scalar output
                                        vs[2] = _scalarStats;
                                
-                               if( inst instanceof MMTSJCPInstruction )
-                               {
+                               if( inst instanceof MMTSJCPInstruction ) {
                                        String type = 
((MMTSJCPInstruction)inst).getMMTSJType().toString();
                                        attr = new String[]{type};
                                } 
-                               else if( inst instanceof 
AggregateUnaryCPInstruction )
-                               {
+                               else if( inst instanceof 
AggregateUnaryCPInstruction ) {
                                        String[] parts = 
InstructionUtils.getInstructionParts(inst.toString());
                                        String opcode = parts[0];
                                        if( opcode.equals("cm") )
-                                               attr = new 
String[]{parts[parts.length-2]};                                             
-                               } 
+                                               attr = new 
String[]{parts[parts.length-2]};
+                               }
                        }
                }
-               else if( inst instanceof BinaryCPInstruction )
-               {
+               else if( inst instanceof BinaryCPInstruction ) {
                        BinaryCPInstruction binst = (BinaryCPInstruction) inst;
                        vs[0] = stats.get( binst.input1.getName() );
                        vs[1] = stats.get( binst.input2.getName() );
                        vs[2] = stats.get( binst.output.getName() );
                        
-                       
-                       if( vs[0] == null ) //scalar input, 
+                       if( vs[0] == null ) //scalar input,
                                vs[0] = _scalarStats;
-                       if( vs[1] == null ) //scalar input, 
+                       if( vs[1] == null ) //scalar input,
                                vs[1] = _scalarStats;
                        if( vs[2] == null ) //scalar output
                                vs[2] = _scalarStats;
-               }       
-               else if( inst instanceof AggregateTernaryCPInstruction )
-               {
+               }
+               else if( inst instanceof AggregateTernaryCPInstruction ) {
                        AggregateTernaryCPInstruction binst = 
(AggregateTernaryCPInstruction) inst;
                        //of same dimension anyway but missing third input
-                       vs[0] = stats.get( binst.input1.getName() ); 
+                       vs[0] = stats.get( binst.input1.getName() );
                        vs[1] = stats.get( binst.input2.getName() );
                        vs[2] = stats.get( binst.output.getName() );
                                
-                       if( vs[0] == null ) //scalar input, 
+                       if( vs[0] == null ) //scalar input,
                                vs[0] = _scalarStats;
-                       if( vs[1] == null ) //scalar input, 
+                       if( vs[1] == null ) //scalar input,
                                vs[1] = _scalarStats;
                        if( vs[2] == null ) //scalar output
                                vs[2] = _scalarStats;
                }
-               else if( inst instanceof ParameterizedBuiltinCPInstruction )
-               {
+               else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
                        //ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction) inst;
                        String[] parts = 
InstructionUtils.getInstructionParts(inst.toString());
                        String opcode = parts[0];
-                       if( opcode.equals("groupedagg") )
-                       {                               
+                       if( opcode.equals("groupedagg") ) {
                                HashMap<String,String> paramsMap = 
ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                                String fn = paramsMap.get("fn");
                                String order = paramsMap.get("order");
                                AggregateOperationTypes type = 
CMOperator.getAggOpType(fn, order);
                                attr = new 
String[]{String.valueOf(type.ordinal())};
                        }
-                       else if( opcode.equals("rmempty") )
-                       {
+                       else if( opcode.equals("rmempty") ) {
                                HashMap<String,String> paramsMap = 
ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
                                attr = new 
String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)};
                        }
-                               
+                       
                        vs[0] = stats.get( 
parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
                        vs[1] = _unknownStats; //TODO
                        vs[2] = stats.get( parts[parts.length-1] );
@@ -389,16 +376,14 @@ public abstract class CostEstimator
                        if( vs[2] == null ) //scalar output
                                vs[2] = _scalarStats;
                }
-               else if( inst instanceof MultiReturnBuiltinCPInstruction )
-               {
+               else if( inst instanceof MultiReturnBuiltinCPInstruction ) {
                        //applies to qr, lu, eigen (cost computation on input1)
                        MultiReturnBuiltinCPInstruction minst = 
(MultiReturnBuiltinCPInstruction) inst;
                        vs[0] = stats.get( minst.input1.getName() );
                        vs[1] = stats.get( minst.getOutput(0).getName() );
                        vs[2] = stats.get( minst.getOutput(1).getName() );
                }
-               else if( inst instanceof VariableCPInstruction )
-               {
+               else if( inst instanceof VariableCPInstruction ) {
                        setUnknownStats(vs);
                        
                        VariableCPInstruction varinst = (VariableCPInstruction) 
inst;
@@ -407,11 +392,10 @@ public abstract class CostEstimator
                                if( stats.containsKey( 
varinst.getInput1().getName() ) )
                                        vs[0] = stats.get( 
varinst.getInput1().getName() );     
                                attr = new 
String[]{varinst.getInput3().getName()};
-                       }       
+                       }
                }
-               else
-               {
-                       setUnknownStats(vs);            
+               else {
+                       setUnknownStats(vs);
                }
                
                //maintain var status (CP output always inmem)
@@ -426,7 +410,7 @@ public abstract class CostEstimator
        private static void setUnknownStats(VarStats[] vs) {
                vs[0] = _unknownStats;
                vs[1] = _unknownStats;
-               vs[2] = _unknownStats;  
+               vs[2] = _unknownStats;
        }
                
        private static long getNumIterations(HashMap<String,VarStats> stats, 
ForProgramBlock pb) {
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java 
b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
index e2a4a75..3f29132 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
@@ -344,30 +344,30 @@ public class CostEstimatorStaticRuntime extends 
CostEstimator
                                                }
                                                return (leftSparse) ? xcm * 
(d1m * d1s + 1) : xcm * d1m;
                                        }
-                                   else if( optype.equals("uatrace") || 
optype.equals("uaktrace") )
-                                       return 2 * d1m * d1n;
-                                   else if( optype.equals("ua+") || 
optype.equals("uar+") || optype.equals("uac+")  ){
-                                       //sparse safe operations
-                                       if( !leftSparse ) //dense
-                                               return d1m * d1n;
-                                       else //sparse
-                                               return d1m * d1n * d1s;
-                                   }
-                                   else if( optype.equals("uak+") || 
optype.equals("uark+") || optype.equals("uack+"))
-                                       return 4 * d1m * d1n; //1*k+
-                                   else if( optype.equals("uasqk+") || 
optype.equals("uarsqk+") || optype.equals("uacsqk+"))
+                                       else if( optype.equals("uatrace") || 
optype.equals("uaktrace") )
+                                               return 2 * d1m * d1n;
+                                       else if( optype.equals("ua+") || 
optype.equals("uar+") || optype.equals("uac+")  ){
+                                               //sparse safe operations
+                                               if( !leftSparse ) //dense
+                                                       return d1m * d1n;
+                                               else //sparse
+                                                       return d1m * d1n * d1s;
+                                       }
+                                       else if( optype.equals("uak+") || 
optype.equals("uark+") || optype.equals("uack+"))
+                                               return 4 * d1m * d1n; //1*k+
+                                       else if( optype.equals("uasqk+") || 
optype.equals("uarsqk+") || optype.equals("uacsqk+"))
                                                return 5 * d1m * d1n; // +1 for 
multiplication to square term
-                                   else if( optype.equals("uamean") || 
optype.equals("uarmean") || optype.equals("uacmean"))
+                                       else if( optype.equals("uamean") || 
optype.equals("uarmean") || optype.equals("uacmean"))
                                                return 7 * d1m * d1n; //1*k+
-                                   else if( optype.equals("uavar") || 
optype.equals("uarvar") || optype.equals("uacvar"))
+                                       else if( optype.equals("uavar") || 
optype.equals("uarvar") || optype.equals("uacvar"))
                                                return 14 * d1m * d1n;
-                                   else if(   optype.equals("uamax") || 
optype.equals("uarmax") || optype.equals("uacmax")
-                                               || optype.equals("uamin") || 
optype.equals("uarmin") || optype.equals("uacmin")
-                                               || optype.equals("uarimax") || 
optype.equals("ua*") )
-                                       return d1m * d1n;
+                                       else if(   optype.equals("uamax") || 
optype.equals("uarmax") || optype.equals("uacmax")
+                                               || optype.equals("uamin") || 
optype.equals("uarmin") || optype.equals("uacmin")
+                                               || optype.equals("uarimax") || 
optype.equals("ua*") )
+                                               return d1m * d1n;
                                        
-                                   return 0;
-                                   
+                                       return 0;
+                               
                                case Binary: //opcodes: +, -, *, /, ^ (incl. 
^2, *2),
                                        //max, min, solve, ==, !=, <, >, <=, >= 
 
                                        //note: all relational ops are not 
sparsesafe
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
new file mode 100644
index 0000000..f4f8db4
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
/**
 * Cost estimate of a federated execution, kept as separate categories: compute cost,
 * read cost, input transfer cost, output transfer cost, and the accumulated total cost
 * of the inputs.
 */
public class FederatedCost {
	protected double _computeCost = 0;
	protected double _readCost = 0;
	protected double _inputTransferCost = 0;
	protected double _outputTransferCost = 0;
	protected double _inputTotalCost = 0;

	/**
	 * Creates a cost object in which every category is zero.
	 */
	public FederatedCost(){}

	/**
	 * Creates a cost object with the given value for each category.
	 * @param readCost read cost
	 * @param inputTransferCost input transfer cost
	 * @param outputTransferCost output transfer cost
	 * @param computeCost compute cost
	 * @param inputTotalCost accumulated total cost of the inputs
	 */
	public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
		double computeCost, double inputTotalCost){
		_computeCost = computeCost;
		_readCost = readCost;
		_inputTransferCost = inputTransferCost;
		_outputTransferCost = outputTransferCost;
		_inputTotalCost = inputTotalCost;
	}

	/**
	 * Sums every cost category held by this object.
	 * @return compute + read + input transfer + output transfer + input total cost
	 */
	public double getTotal(){
		// accumulate left to right in a fixed order so the floating-point result is stable
		double total = _computeCost;
		total += _readCost;
		total += _inputTransferCost;
		total += _outputTransferCost;
		total += _inputTotalCost;
		return total;
	}

	/**
	 * Scales the accumulated input total cost by the number of repetitions.
	 * Only the input total cost is scaled; all other categories stay unchanged.
	 * @param repetitionNumber number of times the input costs are repeated
	 */
	public void addRepetitionCost(int repetitionNumber){
		_inputTotalCost = _inputTotalCost * repetitionNumber;
	}

	/**
	 * Returns the accumulated input total cost.
	 * @return accumulated input total cost
	 */
	public double getInputTotalCost(){
		return _inputTotalCost;
	}

	/**
	 * Overwrites the accumulated input total cost.
	 * @param inputTotalCost new input total cost
	 */
	public void setInputTotalCost(double inputTotalCost){
		_inputTotalCost = inputTotalCost;
	}

	/**
	 * Increases the accumulated input total cost by a plain value.
	 * @param additionalCost amount added to the input total cost
	 */
	public void addInputTotalCost(double additionalCost){
		_inputTotalCost = _inputTotalCost + additionalCost;
	}

	/**
	 * Increases the accumulated input total cost by the overall total of another cost object.
	 * @param federatedCost cost object whose {@link #getTotal()} is added
	 */
	public void addInputTotalCost(FederatedCost federatedCost){
		double inputTotal = federatedCost.getTotal();
		_inputTotalCost = _inputTotalCost + inputTotal;
	}

	/**
	 * Adds every category of another cost object to the matching category of this object.
	 * @param additionalCost cost object whose categories are added
	 */
	public void addFederatedCost(FederatedCost additionalCost){
		FederatedCost other = additionalCost;
		_computeCost += other._computeCost;
		_readCost += other._readCost;
		_inputTransferCost += other._inputTransferCost;
		_outputTransferCost += other._outputTransferCost;
		_inputTotalCost += other._inputTotalCost;
	}

	@Override
	public String toString(){
		return " computeCost: " + _computeCost
			+ "\n readCost: " + _readCost
			+ "\n inputTransferCost: " + _inputTransferCost
			+ "\n outputTransferCost: " + _outputTransferCost
			+ "\n inputTotalCost: " + _inputTotalCost
			+ "\n total cost: " + getTotal();
	}
}
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
new file mode 100644
index 0000000..3e2f994
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+
+/**
+ * Cost estimator for federated executions with methods and constants for 
going through DML programs to estimate costs.
+ */
+public class FederatedCostEstimator {
	// NOTE(review): these use constant-style UPPER_SNAKE_CASE names but are public, non-static,
	// non-final instance fields, so they can be reassigned per estimator instance. Presumably
	// intentional so callers/tests can tune them — confirm before making them final/static.

	// NOTE(review): not read in the visible part of this class; presumably a default size
	// estimate (8, e.g. bytes per double value) — confirm against its users.
	public int DEFAULT_MEMORY_ESTIMATE = 8;
	// Iteration count assumed for while loops: the loop's accumulated input cost is scaled
	// by this value via FederatedCost.addRepetitionCost.
	public int DEFAULT_ITERATION_NUMBER = 15;
	// Default network bandwidth in bytes per second (1024^3 = 2^30 bytes/s; evaluated in int
	// arithmetic, which does not overflow for 2^30).
	public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024;
	// Default compute bandwidth in FLOPS (2.5 * 1024^3).
	// Note: "BANDWITH" is misspelled in the name; kept as-is because the field is public.
	public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024;
	// Default number of parallel processes for workers
	public double WORKER_DEGREE_OF_PARALLELISM = 8;
	// Default read bandwidth in bytes per second (3.5 * 1024^3)
	public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024;

	public boolean printCosts = false; //Temporary for debugging purposes
+
+       /**
+        * Estimate cost of given DML program in bytes.
+        * @param dmlProgram for which the cost is estimated
+        * @return federated cost object with cost estimate in bytes
+        */
+       public FederatedCost costEstimate(DMLProgram dmlProgram){
+               FederatedCost programTotalCost = new FederatedCost();
+               for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() 
)
+                       
programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
+               return programTotalCost;
+       }
+
+       /**
+        * Cost estimate in bytes of given statement block.
+        * @param sb statement block
+        * @return federated cost object with cost estimate in bytes
+        */
+       private FederatedCost costEstimate(StatementBlock sb){
+               if ( sb instanceof WhileStatementBlock){
+                       WhileStatementBlock whileSB = (WhileStatementBlock) sb;
+                       FederatedCost whileSBCost = 
costEstimate(whileSB.getPredicateHops());
+                       for ( Statement statement : whileSB.getStatements() ){
+                               WhileStatement whileStatement = 
(WhileStatement) statement;
+                               for ( StatementBlock bodyBlock : 
whileStatement.getBody() )
+                                       
whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
+                       }
+                       whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
+                       return whileSBCost;
+               }
+               else if ( sb instanceof IfStatementBlock){
+                       //Get cost of if-block + else-block and divide by two
+                       // since only one of the code blocks will be executed 
in the end
+                       IfStatementBlock ifSB = (IfStatementBlock) sb;
+                       FederatedCost ifSBCost = new FederatedCost();
+                       for ( Statement statement : ifSB.getStatements() ){
+                               IfStatement ifStatement = (IfStatement) 
statement;
+                               for ( StatementBlock ifBodySB : 
ifStatement.getIfBody() )
+                                       
ifSBCost.addInputTotalCost(costEstimate(ifBodySB));
+                               for ( StatementBlock elseBodySB : 
ifStatement.getElseBody() )
+                                       
ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
+                       }
+                       
ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
+                       
ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
+                       return ifSBCost;
+               }
+               else if ( sb instanceof ForStatementBlock){
+                       // This also includes ParForStatementBlocks
+                       ForStatementBlock forSB = (ForStatementBlock) sb;
+                       ArrayList<Hop> predicateHops = new ArrayList<>();
+                       predicateHops.add(forSB.getFromHops());
+                       predicateHops.add(forSB.getToHops());
+                       predicateHops.add(forSB.getIncrementHops());
+                       FederatedCost forSBCost = costEstimate(predicateHops);
+                       for ( Statement statement : forSB.getStatements() ){
+                               ForStatement forStatement = (ForStatement) 
statement;
+                               for ( StatementBlock forStatementBlockBody : 
forStatement.getBody() )
+                                       
forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
+                       }
+                       forSBCost.addRepetitionCost(forSB.getEstimateReps());
+                       return forSBCost;
+               }
+               else if ( sb instanceof FunctionStatementBlock){
+                       FederatedCost funcCost = addInitialInputCost(sb);
+                       FunctionStatementBlock funcSB = 
(FunctionStatementBlock) sb;
+                       for(Statement statement : funcSB.getStatements()) {
+                               FunctionStatement funcStatement = 
(FunctionStatement) statement;
+                               for ( StatementBlock funcStatementBody : 
funcStatement.getBody() )
+                                       
funcCost.addInputTotalCost(costEstimate(funcStatementBody));
+                       }
+                       return funcCost;
+               }
+               else {
+                       // StatementBlock type (no subclass)
+                       return costEstimate(sb.getHops());
+               }
+       }
+
+       /**
+        * Creates new FederatedCost object and adds all child statement block 
cost estimates to the object.
+        * @param sb statement block
+        * @return new FederatedCost estimate object with all estimates of 
child statement blocks added
+        */
+       private FederatedCost addInitialInputCost(StatementBlock sb){
+               FederatedCost basicCost = new FederatedCost();
+               for ( StatementBlock childSB : 
sb.getDMLProg().getStatementBlocks() )
+                       
basicCost.addInputTotalCost(costEstimate(childSB).getTotal());
+               return basicCost;
+       }
+
+       /**
+        * Cost estimate in bytes of given list of roots.
+        * The individual cost estimates of the hops are summed.
+        * @param roots list of hops
+        * @return new FederatedCost object with sum of cost estimates of given 
hops
+        */
+       private FederatedCost costEstimate(ArrayList<Hop> roots){
+               FederatedCost basicCost = new FederatedCost();
+               for ( Hop root : roots )
+                       basicCost.addInputTotalCost(costEstimate(root));
+               return basicCost;
+       }
+
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root.
+        * @param root of Hop DAG for which cost is estimated
+        * @return cost estimation of Hop DAG starting from given root
+        */
+       private FederatedCost costEstimate(Hop root){
+               if ( root.federatedCostInitialized() )
+                       return root.getFederatedCost();
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = root.someInputFederated();
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.getInput().stream()
+                               .mapToDouble( in -> 
in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
+                               .sum();
+                       double inputTransferCost = hasFederatedInput ? 
root.getInput().stream()
+                               .filter(Hop::hasLocalOutput)
+                               .mapToDouble(in -> 
in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .map(inMem -> inMem/ 
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+                               .sum() : 0;
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root);
+                       if ( hasFederatedInput ){
+                               //Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
+                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       double outputTransferCost = ( root.hasLocalOutput() && 
hasFederatedInput ) ?
+                               
root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       FederatedCost rootFedCost =
+                               new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+                       root.setFederatedCost(rootFedCost);
+
+                       if ( printCosts )
+                               printCosts(root);
+
+                       return rootFedCost;
+               }
+       }
+
+       /**
+        * Prints costs and information about root for debugging purposes
+        * @param root hop for which information is printed
+        */
+       private static void printCosts(Hop root){
+               System.out.println("===============================");
+               System.out.println(root);
+               System.out.println("Is federated: " + root.isFederated());
+               System.out.println("Has federated output: " + 
root.hasFederatedOutput());
+               System.out.println(root.getText());
+               System.out.println("Pure computeCost: " + 
ComputeCost.getHOPComputeCost(root));
+               System.out.println("Dim1: " + root.getDim1() + " Dim2: " + 
root.getDim2());
+               System.out.println(root.getFederatedCost().toString());
+               System.out.println("===============================");
+       }
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 2e3edb0..04cdf32 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,6 +140,7 @@ public class ProgramRewriter
                        }
                        if ( OptimizerUtils.FEDERATED_COMPILATION ) {
                                _dagRuleSet.add( new 
RewriteFederatedExecution() );
+                               _sbRuleSet.add( new 
RewriteFederatedStatementBlocks() );
                        }
                }
                
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 29cda4a..e6a92ce 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,9 +23,14 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -36,6 +41,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -52,6 +58,9 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.EnumMap;
+import java.util.Map;
 import java.util.concurrent.Future;
 
 public class RewriteFederatedExecution extends HopRewriteRule {
@@ -61,18 +70,115 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
                        return null;
                for ( Hop root : roots )
                        visitHop(root);
+
+               return selectFederatedExecutionPlan(roots);
+       }
+
+       /**
+        * Apply the federated rewrite to a single Hop DAG, mirroring rewriteHopDAGs for one root:
+        * privacy constraints are loaded and propagated (visitHop), then the FederatedOutput of each
+        * hop is selected (visitFedPlanHop).
+        * The given root is returned, as the previously removed implementation did, so callers that use
+        * the result as the rewritten DAG do not lose it (returning null would discard the DAG).
+        * @param root root of the Hop DAG, may be null
+        * @param state program rewrite status (unused)
+        * @return the given root, or null if root is null
+        */
+       @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+               if ( root == null )
+                       return null;
+               visitHop(root);
+               // visitHop marks hops as visited, so reset before selecting the execution plan
+               root.resetVisitStatus();
+               visitFedPlanHop(root);
+               return root;
+       }
+
+       /**
+        * Choose a FederatedOutput value for every Hop reachable from the given roots.
+        * Visit flags of all roots are cleared first, then every root is traversed.
+        * @param roots entry points of the Hop DAG whose FederatedOutput fields are updated
+        * @return the same list of roots, now carrying updated FederatedOutput fields
+        */
+       private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
+               roots.forEach(r -> r.resetVisitStatus());
+               roots.forEach(r -> visitFedPlanHop(r));
                return roots;
        }
 
-       @Override
-       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-               if( root == null )
-                       return null;
-               visitHop(root);
-               return root;
+       /**
+        * Go through the Hop DAG depth-first and set the FederatedOutput field of each Hop,
+        * from the leaves up to the given currentHop. Each hop is processed once (visit flag).
+        * @param currentHop the Hop from which the DAG is visited
+        */
+       private static void visitFedPlanHop(Hop currentHop){
+               if ( currentHop.isVisited() )
+                       return;
+               if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
+                       // Depth first to get to the input
+                       for ( Hop input : currentHop.getInput() )
+                               visitFedPlanHop(input);
+               } else if ( isFederatedDataOp(currentHop) ) {
+                       // leaf federated node
+                       //TODO: This will block the cases where the federated DataOp is based on input that are also federated.
+                       // This means that the actual federated leaf nodes will never be reached.
+                       // NOTE(review): this FOUT value is always overwritten by the if/else below
+                       // (getHighestUtilFedOut or NONE) — confirm it is only meant as a provisional value.
+                       currentHop.setFederatedOutput(FederatedOutput.FOUT);
+               }
+               // Inputs are decided first because isFedInstSupportedHop inspects their FederatedOutput
+               if ( ( isFedInstSupportedHop(currentHop) ) ){
+                       // The Hop can be FOUT or LOUT or None. Check utility of FOUT vs LOUT vs None.
+                       currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
+               }
+               else
+                       currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
+               currentHop.setVisited();
+       }
+
+       /**
+        * Pick the FederatedOutput value with the largest utility among the values allowed for the hop.
+        * FOUT is a candidate when isFOUTSupported(hop) holds, LOUT only when the hop carries no
+        * privacy constraints, and NONE (utility 0) is always a candidate. Ties are resolved by
+        * Collections.max over the EnumMap's iteration order.
+        * @param hop for which the utility is found
+        * @return the FederatedOutput value with highest utility for the given Hop
+        */
+       private static FederatedOutput getHighestUtilFedOut(Hop hop){
+               Map<FederatedOutput,Long> candidates = new EnumMap<>(FederatedOutput.class);
+               if ( isFOUTSupported(hop) )
+                       candidates.put(FederatedOutput.FOUT, getUtilFout());
+               boolean unconstrained = hop.getPrivacy() == null || !hop.getPrivacy().hasConstraints();
+               if ( unconstrained )
+                       candidates.put(FederatedOutput.LOUT, getUtilLout(hop));
+               candidates.put(FederatedOutput.NONE, 0L);
+
+               return Collections.max(candidates.entrySet(), Map.Entry.comparingByValue()).getKey();
        }
-       
-       private void visitHop(Hop hop){
+
+       /**
+        * Utility of choosing FOUT for a hop. Placeholder that returns a constant for every hop.
+        * @return utility if hop is FOUT (currently always 1)
+        */
+       private static long getUtilFout(){
+               //TODO: Make better utility estimation
+               final long constantUtility = 1L;
+               return constantUtility;
+       }
+
+       /**
+        * Utility of choosing LOUT for a hop: the negated memory estimate truncated to a long,
+        * so hops with a larger memory estimate get a lower LOUT utility.
+        * @param hop for which utility is calculated
+        * @return utility if hop is LOUT
+        */
+       private static long getUtilLout(Hop hop){
+               //TODO: Make better utility estimation
+               long memEstimate = (long) hop.getMemEstimate();
+               return -memEstimate;
+       }
+
+       /**
+        * Check whether a federated instruction could be generated for the given hop.
+        * Apart from federated DataOps, at least one input must have federated output, and the hop
+        * must be one of the supported operator types.
+        * @param hop to check
+        * @return true if the hop can be executed as a federated instruction
+        */
+       private static boolean isFedInstSupportedHop(Hop hop){
+               // Without an FOUT input no federated instruction will run, unless the hop is a federated DataOp (fedinit)
+               boolean needsFederatedInput = !isFederatedDataOp(hop);
+               if ( needsFederatedInput && hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
+                       return false;
+
+               // Operator types with federated instruction support
+               return hop instanceof AggBinaryOp
+                       || hop instanceof AggUnaryOp
+                       || hop instanceof BinaryOp
+                       || hop instanceof TernaryOp
+                       || hop instanceof ReorgOp
+                       || hop instanceof DataOp;
+       }
+
+       /**
+        * Checks whether FOUT is an admissible FederatedOutput for the given hop.
+        * Only an AggUnaryOp with scalar output is excluded.
+        * @param associatedHop for which FOUT support is checked
+        * @return true if the associatedHop supports FOUT
+        */
+       private static boolean isFOUTSupported(Hop associatedHop){
+               boolean scalarAggregate = associatedHop instanceof AggUnaryOp && associatedHop.isScalar();
+               return !scalarAggregate;
+       }
+
+       private static void visitHop(Hop hop){
                if (hop.isVisited())
                        return;
 
@@ -84,15 +190,6 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
                hop.setVisited();
        }
 
-       private static void privacyBasedHopDecision(Hop hop){
-               PrivacyPropagator.hopPropagation(hop);
-               PrivacyConstraint privacyConstraint = hop.getPrivacy();
-               if ( privacyConstraint != null && 
privacyConstraint.hasConstraints() )
-                       
hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
-               else if ( hop.someInputFederated() )
-                       
hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
-       }
-
        /**
         * Get privacy constraints of DataOps from federated worker,
         * propagate privacy constraints from input to current hop,
@@ -101,7 +198,7 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
         */
        private static void privacyBasedHopDecisionWithFedCall(Hop hop){
                loadFederatedPrivacyConstraints(hop);
-               privacyBasedHopDecision(hop);
+               // Only privacy propagation happens here; FederatedOutput is no longer set here (see visitFedPlanHop)
+               PrivacyPropagator.hopPropagation(hop);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
new file mode 100644
index 0000000..18b36d5
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Statement block rewrite rule that is added to the statement block rule set when federated
+ * compilation is enabled. It currently performs no rewrite and returns its input unchanged.
+ */
+public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
+
+       /**
+        * Reports whether this rewrite may split DAGs, which is used for phase ordering of rewrites.
+        *
+        * @return always {@code false}, since this rule never splits DAGs
+        */
+       @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+
+       /**
+        * Leave a single statement block untouched.
+        *
+        * @param sb    statement block
+        * @param state program rewrite status (unused)
+        * @return a single-element list containing {@code sb}
+        */
+       @Override
+       public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+               List<StatementBlock> unchanged = Arrays.asList(sb);
+               return unchanged;
+       }
+
+       /**
+        * Leave a sequence of statement blocks untouched.
+        *
+        * @param sbs   list of statement blocks
+        * @param state program rewrite status (unused)
+        * @return the same list instance {@code sbs}
+        */
+       @Override
+       public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+               final List<StatementBlock> unchanged = sbs;
+               return unchanged;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index bae38a2..755287a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -64,6 +64,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
                String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
                String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+               String2FEDInstructionType.put( "rev"    , FEDType.Reorg );
 
                // Ternary Instruction Opcodes
                String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index fb0647e..2e5366e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -124,9 +125,60 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                        new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()}, true);
                map.execute(getTID(), fr1);
 
-               // derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
-               
out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+               deriveNewOutputFedMapping(in, out, fr1);
+       }
+
+       /**
+        * Set the output federation mapping based on the input's federated partitioning and the aggregation direction.
+        * Only the opcodes "uack+" (column aggregation) and "uark+" (row aggregation) are handled.
+        * @param in matrix object from which the federated partitioning originates
+        * @param out matrix object holding the dimensions of the instruction output; receives the new mapping
+        * @param fr1 federated request holding the instruction execution call; its ID is used for the new mapping
+        * @throws DMLRuntimeException if the opcode is neither "uack+" nor "uark+"
+        */
+       private void deriveNewOutputFedMapping(MatrixObject in, MatrixObject out, FederatedRequest fr1){
+               //Get agg type: only column (uack+) and row (uark+) aggregation are supported
+               if ( !(instOpcode.equals("uack+") || instOpcode.equals("uark+")) )
+                       throw new DMLRuntimeException("Operation " + instOpcode + " is unknown to FOUT processing");
+               boolean isColAgg = instOpcode.equals("uack+");
+               //Get partition type
+               FederationMap.FType inFtype = in.getFedMapping().getType();
+               //Get fedmap from in
+               //NOTE(review): the ranges of this copy are mutated in place below; this assumes copyWithNewID
+               //copies the FederatedRange objects instead of sharing them with in's mapping — confirm.
+               FederationMap inputFedMapCopy = in.getFedMapping().copyWithNewID(fr1.getID());
+
+               //NOTE(review): the three cases below are independent ifs, not an else-if chain. For an FType
+               //reporting both isRowPartitioned() and isColPartitioned(), several branches run and the last one wins.
+
+               //if partition type is row and aggregation type is row
+               //   then keep the row split from the input, take the column end from the output column dimension
+               //   and set FType to ROW
+               if ( inFtype.isRowPartitioned() && !isColAgg ){
+                       for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+                               range.setEndDim(1,out.getNumColumns());
+                       inputFedMapCopy.setType(FederationMap.FType.ROW);
+               }
+               //if partition type is row and aggregation type is col, or
+               //if partition type is col and aggregation type is row,
+               //   then every worker range spans the full output dimensions
+               //   and FType is set to PART
+               if ( (inFtype.isRowPartitioned() && isColAgg) || (inFtype.isColPartitioned() && !isColAgg) ){
+                       for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() ){
+                               range.setBeginDim(0,0);
+                               range.setBeginDim(1,0);
+                               range.setEndDim(0,out.getNumRows());
+                               range.setEndDim(1,out.getNumColumns());
+                       }
+                       inputFedMapCopy.setType(FederationMap.FType.PART);
+               }
+               //if partition type is col and aggregation type is col
+               //   then take the row end from the output row dimension, keep the column split from the input
+               //   and set FType to COL
+               if ( inFtype.isColPartitioned() && isColAgg ){
+                       for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+                               range.setEndDim(0,out.getNumRows());
+                       inputFedMapCopy.setType(FederationMap.FType.COL);
+               }
+
+               //set out fedmap in the end
+               out.setFedMapping(inputFedMapCopy);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 825b984..c2e7ab1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -130,8 +130,9 @@ public class AppendFEDInstruction extends 
BinaryFEDInstruction {
                }
                else {
                        throw new DMLRuntimeException("Unsupported federated 
append: "
-                               + (mo1.isFederated() ? 
mo1.getFedMapping().getType().name():"LOCAL") + " "
-                               + (mo2.isFederated() ? 
mo2.getFedMapping().getType().name():"LOCAL") + " " + _cbind);
+                               + " input 1 FType is " + (mo1.isFederated() ? 
mo1.getFedMapping().getType().name():"LOCAL")
+                               + ", input 2 FType is " + (mo2.isFederated() ? 
mo2.getFedMapping().getType().name():"LOCAL")
+                               + ", and column bind is " + _cbind);
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2a1cbb1..8795308 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -217,10 +217,10 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
         * @param mo2 input matrix object mo2
         * @return boolean indicating if the output can be kept on the 
federated sites
         */
-       private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+       private static boolean isFedOutput(FederationMap fedMap, MatrixObject 
mo2) {
                MatrixBlock mb = mo2.acquireReadAndRelease();
                FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // 
federated ranges of mo1
-               SortedMap<Double, Double> fedDims = new TreeMap<Double, 
Double>(); // <beginDim, endDim>
+               SortedMap<Double, Double> fedDims = new TreeMap<>(); // 
<beginDim, endDim>
 
                // collect min and max of the corresponding slices of mo2
                IntStream.range(0, fedRanges.length).forEach(i -> {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f35030f..0e00faa 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -60,6 +60,9 @@ public abstract class FEDInstruction extends Instruction {
                public boolean isForcedLocal() {
                        return this == LOUT;
                }
+               /**
+                * @return true if the output location is forced, i.e. this is LOUT or FOUT
+                */
+               public boolean isForced(){
+                       return isForcedLocal() || this == FOUT;
+               }
        }
 
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 19847a2..cb3074f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Future;
 import java.util.stream.Stream;
 
 import org.apache.commons.lang3.tuple.Pair;
@@ -52,7 +53,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
-       
+       private static boolean fedoutFlagInString = false;
+
        public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, 
String opcode, String istr, FederatedOutput fedOut) {
                super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
        }
@@ -80,6 +82,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                        return new ReorgFEDInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
                }
                else if ( opcode.equalsIgnoreCase("rev") ) {
+                       fedoutFlagInString = parts.length > 3;
                        parseUnaryInstruction(str, in, out); //max 2 operands
                        return new ReorgFEDInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
                }
@@ -96,24 +99,34 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                if( !mo1.isFederated() )
                        throw new DMLRuntimeException("Federated Reorg: "
                                + "Federated input expected, but invoked w/ 
"+mo1.isFederated());
+               if ( !( mo1.isFederated(FederationMap.FType.COL) || 
mo1.isFederated(FederationMap.FType.ROW)) )
+                       throw new DMLRuntimeException("Federation type " + 
mo1.getFedMapping().getType()
+                               + " is not supported for Reorg processing");
 
                if(instOpcode.equals("r'")) {
                        //execute transpose at federated site
                        FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
                                output, new CPOperand[] {input1},
                                new long[] {mo1.getFedMapping().getID()}, true);
-                       mo1.getFedMapping().execute(getTID(), true, fr1);
+                       if (_fedOut != null && !_fedOut.isForcedLocal()){
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1);
 
-                       //drive output federated mapping
-                       MatrixObject out = ec.getMatrixObject(output);
-                       out.getDataCharacteristics().set(mo1.getNumColumns(), 
mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
-                       
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+                               //drive output federated mapping
+                               MatrixObject out = ec.getMatrixObject(output);
+                               
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) 
mo1.getBlocksize(), mo1.getNnz());
+                               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+                       } else {
+                               FederatedRequest getRequest = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+                               Future<FederatedResponse>[] execResponse = 
mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
+                               ec.setMatrixOutput(output.getName(),
+                                       FederationUtils.bind(execResponse, 
mo1.isFederated(FederationMap.FType.COL)));
+                       }
                }
                else if(instOpcode.equalsIgnoreCase("rev")) {
                        //execute transpose at federated site
                        FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
                                output, new CPOperand[] {input1},
-                               new long[] {mo1.getFedMapping().getID()}, true);
+                               new long[] {mo1.getFedMapping().getID()}, 
fedoutFlagInString);
                        mo1.getFedMapping().execute(getTID(), true, fr1);
 
                        if(mo1.isFederated(FederationMap.FType.ROW))
@@ -123,6 +136,11 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                        MatrixObject out = ec.getMatrixObject(output);
                        out.getDataCharacteristics().set(mo1.getNumRows(), 
mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
                        
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+
+                       if ( _fedOut != null && _fedOut.isForcedLocal() ){
+                               out.acquireReadAndRelease();
+                               out.getFedMapping().cleanup(getTID(), 
fr1.getID());
+                       }
                }
                else if (instOpcode.equals("rdiag")) {
                        RdiagResult result;
@@ -158,6 +176,10 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                                .set(diagFedMap.getMaxIndexInRange(0), 
diagFedMap.getMaxIndexInRange(1),
                                        (int) mo1.getBlocksize());
                        rdiag.setFedMapping(diagFedMap);
+                       if ( _fedOut != null && _fedOut.isForcedLocal() ){
+                               rdiag.acquireReadAndRelease();
+                               rdiag.getFedMapping().cleanup(getTID(), 
rdiag.getFedMapping().getID());
+                       }
                }
        }
 
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3d96cd9..3ac462c 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import org.apache.commons.io.FileUtils;
@@ -2097,6 +2098,31 @@ public abstract class AutomatedTestBase {
                return false;
        }
 
+       /**
+        * Checks whether every given opcode is contained in the set of CP heavy hitter opcodes.
+        * @param str opcodes to look up among the heavy hitters
+        * @return true if all given opcodes are heavy hitters (also true if no opcode is given)
+        */
+       protected boolean heavyHittersContainsAllString(String... str){
+               Set<String> cpHeavyHitters = Statistics.getCPHeavyHitterOpCodes();
+               for ( String opcode : str ) {
+                       if ( !cpHeavyHitters.contains(opcode) )
+                               return false;
+               }
+               return true;
+       }
+
+       /**
+        * Collects the given opcodes that are not contained in the set of CP heavy hitter opcodes.
+        * @param opcodes opcodes to look up among the heavy hitters
+        * @return the missing opcodes in the order they were given; empty if none is missing
+        */
+       protected String[] missingHeavyHitters(String... opcodes){
+               Set<String> cpHeavyHitters = Statistics.getCPHeavyHitterOpCodes();
+               return Arrays.stream(opcodes)
+                       .filter(opcode -> !cpHeavyHitters.contains(opcode))
+                       .toArray(String[]::new);
+       }
+
        protected boolean heavyHittersContainsString(String str, int minCount) {
                int count = 0;
                for(String opcode : Statistics.getCPHeavyHitterOpCodes())
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000..0092a3a
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.cost.FederatedCost;
+import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.sysds.common.Types.OpOp2.MULT;
+
+public class FederatedCostEstimatorTest extends AutomatedTestBase {
+
+       private static final String TEST_DIR = "functions/privacy/fedplanning/";
+       private static final String HOME = SCRIPT_DIR + TEST_DIR;
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/";
+
+       FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+
+       /** Hops of the compiled program, filled by {@link #addHop(Hop)}. */
+       Set<Hop> hops = new HashSet<>();
+
+       @Override
+       public void setUp() {}
+
+       @Test
+       public void simpleBinary() {
+               setWorkerBandwidths();
+
+               /*
+                * HOP          Occurrences  ComputeCost  ReadCost  ComputeCostFinal  ReadCostFinal
+                * -------------------------------------------------------------------------------
+                * LiteralOp    16           1            0         0.0625            0
+                * DataGenOp    2            100          64        6.25              6.4
+                * BinaryOp     1            100          1600      6.25              160
+                * TOSTRING     1            1            800       0.0625            80
+                * UnaryOp      1            1            8         0.0625            0.8
+                */
+               double expectedCost = binaryBodyCost();
+               runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void ifElseTest(){
+               setWorkerBandwidths();
+               double expectedCost = ((binaryBodyCost() + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+               runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void whileTest(){
+               setWorkerBandwidths();
+               double expectedCost = (binaryBodyCost() + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+               runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void forLoopTest(){
+               setWorkerBandwidths();
+               double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+               double expectedCost = (binaryBodyCost() + predicateCost) * 5;
+               runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void parForLoopTest(){
+               setWorkerBandwidths();
+               double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+               double expectedCost = (binaryBodyCost() + predicateCost) * 5;
+               runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void functionTest(){
+               setWorkerBandwidths();
+               double expectedCost = binaryBodyCost();
+               runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+       }
+
+       @Test
+       public void federatedMultiply() {
+               setWorkerBandwidths();
+               fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+
+               double literalOpCost = 10*0.0625;
+               double naryOpCostSpecial = (0.125+2.2);
+               double naryOpCostSpecial2 = (0.25+6.4);
+               double naryOpCost = 4*(0.125+1.6);
+               double reorgOpCost = 6250+80015.2+160030.4;
+               double binaryOpMultCost = 3125+160000;
+               double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+               double dataOpCost = 2*(6250+5.6);
+               double dataOpWriteCost = 6.25+100.3;
+
+               double expectedCost = literalOpCost + naryOpCost + naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+                       + binaryOpMultCost + aggBinaryOpCost + dataOpCost + dataOpWriteCost;
+               runTest("FederatedMultiplyCostEstimatorTest.dml", false, expectedCost);
+
+               Assert.assertEquals(aggBinaryOpCost, ownCostOf(AggBinaryOp.class), 0.0001);
+               Assert.assertEquals(dataOpWriteCost+dataOpCost, ownCostOf(DataOp.class), 0.0001);
+       }
+
+       /**
+        * Sets the worker compute and read bandwidths that the expected-cost formulas of all tests assume.
+        */
+       private void setWorkerBandwidths() {
+               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+       }
+
+       /**
+        * Expected compute plus read cost of the hops listed in the table of {@link #simpleBinary()},
+        * based on the currently configured worker bandwidths and degree of parallelism.
+        * @return expected cost of the multiply-and-print script body
+        */
+       private double binaryBodyCost() {
+               double computeCost = (16+2*100+100+1+1)
+                       / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               return computeCost + readCost;
+       }
+
+       /**
+        * Sums, over all collected hops of the given type, the federated cost of the hop itself,
+        * i.e. its total federated cost minus the total cost of its inputs.
+        * @param hopType hop class to filter the collected hops by
+        * @return summed own cost of the matching hops
+        */
+       private double ownCostOf(Class<? extends Hop> hopType) {
+               return hops.stream()
+                       .filter(hopType::isInstance)
+                       .mapToDouble(hop -> hop.getFederatedCost().getTotal() - hop.getFederatedCost().getInputTotalCost())
+                       .sum();
+       }
+
+       /**
+        * Recursively adds the hop and its inputs to the set of hops.
+        * Hops already in the set are not traversed again, so shared sub-DAGs are visited once.
+        * @param hop root to be added to set of hops
+        */
+       private void addHop(Hop hop){
+               if ( !hops.add(hop) )
+                       return;
+               for(Hop inHop : hop.getInput()){
+                       addHop(inHop);
+               }
+       }
+
+       /**
+        * Adds the hops of all statement blocks returned by {@code prog.getStatementBlocks()} to the set of hops.
+        * @param prog dml program whose hops are collected
+        */
+       private void collectHops(DMLProgram prog){
+               prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+       }
+
+       /**
+        * Sets dimensions of federated X and Y and sets binary multiplication to FOUT.
+        * DataOps are also set to FOUT; every other hop is set to LOUT.
+        * @param prog dml program where the HOPS are modified
+        */
+       private void modifyFedouts(DMLProgram prog){
+               collectHops(prog);
+               hops.forEach(hop -> {
+                       if ( hop instanceof DataOp || (hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == MULT ) ){
+                               hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+                               hop.setExecType(Types.ExecType.FED);
+                       } else {
+                               hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+                       }
+                       if ( hop.getOpString().equals("Fed Y") || hop.getOpString().equals("Fed X") ){
+                               hop.setDim1(10000);
+                               hop.setDim2(10);
+                       }
+               });
+       }
+
+       @SuppressWarnings("unused")
+       private void printHopsInfo(){
+               //LiteralOp
+               long literalCount = hops.stream().filter(hop -> hop instanceof LiteralOp).count();
+               System.out.println("LiteralOp Count: " + literalCount);
+               //NaryOp
+               long naryCount = hops.stream().filter(hop -> hop instanceof NaryOp).count();
+               System.out.println("NaryOp Count " + naryCount);
+               //ReorgOp
+               long reorgCount = hops.stream().filter(hop -> hop instanceof ReorgOp).count();
+               System.out.println("ReorgOp Count: " + reorgCount);
+               //BinaryOp
+               long binaryCount = hops.stream().filter(hop -> hop instanceof BinaryOp).count();
+               System.out.println("Binary count: " + binaryCount);
+               //AggBinaryOp
+               long aggBinaryCount = hops.stream().filter(hop -> hop instanceof AggBinaryOp).count();
+               System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+               //DataOp
+               long dataOpCount = hops.stream().filter(hop -> hop instanceof DataOp).count();
+               System.out.println("DataOp Count: " + dataOpCount);
+
+               hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+       }
+
+       /**
+        * Compiles the given script to hops, estimates its federated cost and compares it to the expected cost.
+        * @param scriptFilename name of the DML script in the test script directory
+        * @param expectedException whether a LanguageException is expected while compiling
+        * @param expectedCost expected total federated cost of the program
+        */
+       private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) {
+               boolean raisedException = false;
+               try
+               {
+                       setTestConfig(scriptFilename);
+                       String dmlScriptString = readScript(scriptFilename);
+
+                       //parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
+                       ParserWrapper parser = ParserFactory.createParser();
+                       DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+                       DMLTranslator dmlt = new DMLTranslator(prog);
+                       dmlt.liveVariableAnalysis(prog);
+                       dmlt.validateParseTree(prog);
+                       dmlt.constructHops(prog);
+                       if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+                               modifyFedouts(prog);
+                               dmlt.rewriteHopsDAG(prog);
+                               // recollect the hops so that later assertions see the rewritten DAG
+                               hops = new HashSet<>();
+                               collectHops(prog);
+                       }
+
+                       FederatedCost actualCost = fedCostEstimator.costEstimate(prog);
+                       Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001);
+               }
+               catch(LanguageException ex) {
+                       raisedException = true;
+                       if(raisedException!=expectedException)
+                               ex.printStackTrace();
+               }
+               catch(Exception ex2) {
+                       // the cause is preserved, so no separate stack trace print is needed
+                       throw new RuntimeException(ex2);
+               }
+
+               Assert.assertEquals("Expected exception does not match raised exception",
+                       expectedException, raisedException);
+       }
+
+       private void setTestConfig(String scriptFilename) throws FileNotFoundException {
+               int index = scriptFilename.lastIndexOf(".dml");
+               String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length());
+               TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+               addTestConfiguration(testName, testConfig);
+               loadTestConfiguration(testConfig);
+
+               DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+               ConfigurationManager.setLocalConfig(conf);
+       }
+
+       private static String readScript(String scriptFilename) throws IOException {
+               return DMLScript.readDMLScript(true, HOME + scriptFilename);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 342a26b..e0ef884 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -35,7 +35,7 @@ import org.apache.sysds.test.TestUtils;
 import java.util.Arrays;
 import java.util.Collection;
 
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
@@ -76,34 +76,40 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
 
        @Test
        public void federatedMultiplyCP() {
-               federatedTwoMatricesSingleNodeTest(TEST_NAME);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME,
+                       new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"});
        }
 
        @Test
        public void federatedRowSum(){
-               federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_2,
+                       new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"});
        }
 
        @Test
        public void federatedTernarySequence(){
-               federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_3,
+                       new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"});
        }
 
        @Test
        public void federatedAggregateBinarySequence(){
                cols = rows;
-               federatedTwoMatricesSingleNodeTest(TEST_NAME_4);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_4,
+                       new String[]{"fed_ba+*", "fed_*", "fed_fedinit"});
        }
 
        @Test
        public void federatedAggregateBinaryColFedSequence(){
                cols = rows;
-               federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_5,
+                       new String[]{"fed_ba+*", "fed_*", "fed_fedinit"});
        }
 
        @Test
        public void federatedAggregateBinarySequence2(){
-               federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
+               // federated opcodes that must all show up among the heavy hitters
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_6,
+                       new String[]{"fed_ba+*", "fed_fedinit"});
        }
 
        private void writeStandardMatrix(String matrixName, long seed){
@@ -166,11 +172,11 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                }
        }
 
-       private void federatedTwoMatricesSingleNodeTest(String testName){
-               federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+       /**
+        * Runs the two-matrix federated test in SINGLE_NODE execution mode.
+        * @param testName name of the test
+        * @param expectedHeavyHitters opcodes that must all appear among the heavy hitters
+        */
+       private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){
+               Types.ExecMode singleNode = Types.ExecMode.SINGLE_NODE;
+               federatedTwoMatricesTest(singleNode, testName, expectedHeavyHitters);
        }
 
-       private void federatedTwoMatricesTest(Types.ExecMode execMode, String 
testName) {
+       private void federatedTwoMatricesTest(Types.ExecMode execMode, String 
testName, String[] expectedHeavyHitters) {
                OptimizerUtils.FEDERATED_COMPILATION = true;
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
@@ -213,10 +219,9 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
 
                // compare via files
                compareResults(1e-9);
-               if ( testName.equals(TEST_NAME_3) )
-                       assertTrue(heavyHittersContainsString("fed_+*", 
"fed_1-*"));
-               else
-                       assertTrue(heavyHittersContainsString("fed_*", 
"fed_ba+*"));
+               if (!heavyHittersContainsAllString(expectedHeavyHitters))
+                       fail("The following expected heavy hitters are missing: 
"
+                               + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
 
                TestUtils.shutdownThreads(t1, t2);
 
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml 
b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
new file mode 100644
index 0000000..1899614
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Two constant 10x10 matrices, an element-wise multiplication and a print of
+# the result; used by FederatedCostEstimatorTest.simpleBinary.
+# NOTE: the expected costs in FederatedCostEstimatorTest are derived from the
+# hops this script compiles to, so changing a statement changes those costs.
+a = matrix (1, rows=10, cols=10);
+b = matrix (2, rows=10, cols=10);
+c = a * b;
+print(toString(c));
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
new file mode 100644
index 0000000..dfeacec
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# FederatedCostEstimatorTest only compiles this script to hops and costs them
+# (it is not executed there), so worker addresses and output path are placeholders.
+ph = "placeholder"
+rows = 10000
+cols = 10
+# X and Y: federated matrices split by rows into two ranges,
+# rows 0..rows/2 and rows/2..rows, each spanning all cols columns.
+X = federated(addresses=list(ph, ph),
+              ranges=list(list(0, 0), list(rows / 2, cols), list(rows / 2, 0), 
list(rows, cols)))
+Y = federated(addresses=list(ph, ph),
+              ranges=list(list(0, 0), list(rows/2, cols), list(rows / 2, 0), 
list(rows, cols)))
+# element-wise product, then transpose times X
+Z0 = X * Y
+Z = t(Z0) %*% X
+write(Z, ph)
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml 
b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..b80e745
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+for (j in 1:5) {  # fixed trip count: the body runs 5 times
+  lhs = matrix(1, rows=10, cols=10);  # 10x10 matrix filled with 1
+  rhs = matrix(2, rows=10, cols=10);  # 10x10 matrix filled with 2
+  res = lhs * rhs;  # element-wise multiplication
+  print(toString(res));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
new file mode 100644
index 0000000..1f0d876
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+multiplication = function() {  # no inputs and no declared return values; prints its result
+  lhs = matrix(1, rows=10, cols=10);  # 10x10 matrix filled with 1
+  rhs = matrix(2, rows=10, cols=10);  # 10x10 matrix filled with 2
+  res = lhs * rhs;  # element-wise multiplication
+  print(toString(res));
+}
+multiplication();  # single call so the function body is part of the program
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
new file mode 100644
index 0000000..b0194e3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if (1) {  # constant condition; both branches are present for the estimator to see
+  lhs = matrix(1, rows=10, cols=10);  # 10x10 matrix filled with 1
+  rhs = matrix(2, rows=10, cols=10);  # 10x10 matrix filled with 2
+  res = lhs * rhs;  # element-wise multiplication
+  print(toString(res));
+}
+else {
+  print("No result");
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..32a49ec
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+parfor (j in 1:5) {  # 5 iterations with no cross-iteration dependencies
+  lhs = matrix(1, rows=10, cols=10);  # 10x10 matrix filled with 1
+  rhs = matrix(2, rows=10, cols=10);  # 10x10 matrix filled with 2
+  res = lhs * rhs;  # element-wise multiplication
+  print(toString(res));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
new file mode 100644
index 0000000..faea01d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+while (1) {  # NOTE(review): constant condition never becomes false; presumably this script is only compiled for cost estimation, not executed - confirm
+  lhs = matrix(1, rows=10, cols=10);  # 10x10 matrix filled with 1
+  rhs = matrix(2, rows=10, cols=10);  # 10x10 matrix filled with 2
+  res = lhs * rhs;  # element-wise multiplication
+  print(toString(res));
+}
\ No newline at end of file

Reply via email to