You are viewing a plain-text version of this content. The canonical link for it was shown as a hyperlink ("here") that is not preserved in this plain-text rendering.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/10 18:50:56 UTC

[systemds] branch master updated: [SYSTEMDS-3018] Costing of Federated Execution Plans

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 6523457  [SYSTEMDS-3018] Costing of Federated Execution Plans
6523457 is described below

commit 65234573b986e8d0ba5bd27bd3d7641b921a7641
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Jun 18 17:31:31 2021 +0200

    [SYSTEMDS-3018] Costing of Federated Execution Plans
    
    Closes #1367.
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  93 +++++--
 .../codegen/opt/PlanSelectionFuseCostBasedV2.java  | 182 +-------------
 .../org/apache/sysds/hops/cost/ComputeCost.java    | 225 +++++++++++++++++
 .../sysds/hops/cost/CostEstimationWrapper.java     |  13 +-
 .../org/apache/sysds/hops/cost/CostEstimator.java  |  72 +++---
 .../hops/cost/CostEstimatorStaticRuntime.java      |  40 +--
 .../org/apache/sysds/hops/cost/FederatedCost.java  | 117 +++++++++
 .../sysds/hops/cost/FederatedCostEstimator.java    | 214 ++++++++++++++++
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   1 +
 .../hops/rewrite/RewriteFederatedExecution.java    | 133 ++++++++--
 .../rewrite/RewriteFederatedStatementBlocks.java   |  66 +++++
 .../runtime/instructions/FEDInstructionParser.java |   1 +
 .../fed/AggregateUnaryFEDInstruction.java          |  56 ++++-
 .../instructions/fed/AppendFEDInstruction.java     |   5 +-
 .../instructions/fed/CtableFEDInstruction.java     |   4 +-
 .../runtime/instructions/fed/FEDInstruction.java   |   3 +
 .../instructions/fed/ReorgFEDInstruction.java      |  36 ++-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  26 ++
 .../fedplanning/FederatedCostEstimatorTest.java    | 279 +++++++++++++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest.java |  33 +--
 .../fedplanning/BinaryCostEstimatorTest.dml        |  26 ++
 .../FederatedMultiplyCostEstimatorTest.dml         |  31 +++
 .../fedplanning/ForLoopCostEstimatorTest.dml       |  27 ++
 .../fedplanning/FunctionCostEstimatorTest.dml      |  28 +++
 .../fedplanning/IfElseCostEstimatorTest.dml        |  30 +++
 .../fedplanning/ParForLoopCostEstimatorTest.dml    |  27 ++
 .../privacy/fedplanning/WhileCostEstimatorTest.dml |  27 ++
 27 files changed, 1483 insertions(+), 312 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index ececf52..45fc3af 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.cost.FederatedCost;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.hops.recompile.Recompiler.ResetType;
 import org.apache.sysds.lops.CSVReBlock;
@@ -91,8 +92,9 @@ public abstract class Hop implements ParseInfo {
 	 * If it is lout, the output should be retrieved by the coordinator.
 	 */
 	protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+	protected FederatedCost _federatedCost = new FederatedCost();
 	
-	// Estimated size for the output produced from this Hop
+	// Estimated size for the output produced from this Hop in bytes
 	protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
 	
 	// Estimated size for the entire operation represented by this Hop
@@ -535,7 +537,7 @@ public abstract class Hop implements ParseInfo {
 	 * only use getMemEstimate(), which gives memory required to store 
 	 * all inputs and the output.
 	 * 
-	 * @return output size memory estimate
+	 * @return output size memory estimate in bytes
 	 */
 	protected double getOutputSize() {
 		return _outputMemEstimate;
@@ -545,14 +547,22 @@ public abstract class Hop implements ParseInfo {
 		return getInputSize(null);
 	}
 
-	protected double getInputSize(Collection<String> exclVars) {
+	/**
+	 * Get the memory estimate of inputs as the sum of input estimates in bytes.
+	 * @param exclVars name of input hops to exclude from the input estimate
+	 * @param injectedDefault default memory estimate (bytes) used when the memory estimate of the input is negative
+	 * @return input memory estimate in bytes
+	 */
+	protected double getInputSize(Collection<String> exclVars, double injectedDefault){
 		double sum = 0;
 		int len = _input.size();
 		for( int i=0; i<len; i++ ) { //for all inputs
 			Hop hi = _input.get(i);
 			if( exclVars != null && exclVars.contains(hi.getName()) )
 				continue;
-			double hmout = hi.getOutputMemEstimate();
+			double hmout = hi.getOutputMemEstimate(injectedDefault);
+			if (hmout < 0)
+				hmout = injectedDefault*(Math.max(hi.getDim1(),1) * Math.max(hi.getDim2(),1));
 			if( hmout > 1024*1024 ) {//for relevant sizes
 				//check if already included in estimate (if an input is used
 				//multiple times it is still only required once in memory)
@@ -564,10 +574,19 @@ public abstract class Hop implements ParseInfo {
 			}
 			sum += hmout;
 		}
-		
+
 		return sum;
 	}
 
+	/**
+	 * Get the memory estimate of inputs as the sum of input estimates in bytes.
+	 * @param exclVars name of input hops to exclude from the input estimate
+	 * @return input memory estimate in bytes
+	 */
+	protected double getInputSize(Collection<String> exclVars) {
+		return getInputSize(exclVars, OptimizerUtils.INVALID_SIZE);
+	}
+
 	protected double getInputSize( int pos ){
 		double ret = 0;
 		if( _input.size()>pos )
@@ -582,12 +601,11 @@ public abstract class Hop implements ParseInfo {
 	/**
 	 * NOTES:
 	 * * Purpose: Whenever the output dimensions / sparsity of a hop are unknown, this hop
-	 *   should store its worst-case output statistics (if known) in that table. Subsequent
-	 *   hops can then
+	 *   should store its worst-case output statistics (if known) in that table.
 	 * * Invocation: Intended to be called for ALL root nodes of one Hops DAG with the same
 	 *   (initially empty) memo table.
 	 * 
-	 * @return memory estimate
+	 * @return memory estimate in bytes
 	 */
 	public double getMemEstimate() {
 		if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
@@ -620,17 +638,46 @@ public abstract class Hop implements ParseInfo {
 	}
 
 	//wrappers for meaningful public names to memory estimates.
-	
+
+	/**
+	 * Get the memory estimate of inputs as the sum of input estimates in bytes.
+	 * @return input memory estimate in bytes
+	 */
 	public double getInputMemEstimate()
 	{
 		return getInputSize();
 	}
-	
+
+	/**
+	 * Get the memory estimate of inputs as the sum of input estimates in bytes.
+	 * @param injectedDefault default memory estimate (bytes) used when the memory estimate of the input is negative
+	 * @return input memory estimate in bytes
+	 */
+	public double getInputMemEstimate(double injectedDefault){
+		return getInputSize(null, injectedDefault);
+	}
+
+	/**
+	 * Output memory estimate in bytes.
+	 * @return output memory estimate in bytes
+	 */
 	public double getOutputMemEstimate()
 	{
 		return getOutputSize();
 	}
 
+	/**
+	 * Output memory estimate in bytes with negative memory estimates replaced by the injected default.
+	 * The injected default represents the memory estimate per output cell, hence it is multiplied by the estimated
+	 * dimensions of the output of the hop.
+	 * @param injectedDefault memory estimate to be returned in case the memory estimate defaults to a negative number
+	 * @return output memory estimate in bytes
+	 */
+	public double getOutputMemEstimate(double injectedDefault)
+	{
+		return Math.max(getOutputMemEstimate(),injectedDefault*(Math.max(getDim1(),1) * Math.max(getDim2(),1)));
+	}
+
 	public double getIntermediateMemEstimate()
 	{
 		return getIntermediateSize();
@@ -823,17 +870,13 @@ public abstract class Hop implements ParseInfo {
 	 * This method only has an effect if FEDERATED_COMPILATION is activated.
 	 */
 	protected void updateETFed(){
-		if ( _federatedOutput == FederatedOutput.FOUT || _federatedOutput == FederatedOutput.LOUT )
+		if ( _federatedOutput.isForced() )
 			_etype = ExecType.FED;
 	}
 	
 	public boolean isFederated(){
 		return getExecType() == ExecType.FED;
 	}
-	
-	public boolean isFederatedOutput(){
-		return _federatedOutput == FederatedOutput.FOUT;
-	}
 
 	public boolean someInputFederated(){
 		return getInput().stream().anyMatch(Hop::hasFederatedOutput);
@@ -889,6 +932,26 @@ public abstract class Hop implements ParseInfo {
 		return _federatedOutput == FederatedOutput.FOUT;
 	}
 
+	public boolean hasLocalOutput(){
+		return _federatedOutput == FederatedOutput.LOUT;
+	}
+
+	/**
+	 * Check if federated cost has been initialized for this Hop.
+	 * @return true if federated cost has been initialized
+	 */
+	public boolean federatedCostInitialized(){
+		return _federatedCost.getTotal() > 0;
+	}
+
+	public FederatedCost getFederatedCost(){
+		return _federatedCost;
+	}
+
+	public void setFederatedCost(FederatedCost cost){
+		_federatedCost = cost;
+	}
+
 	public void setUpdateType(UpdateType update){
 		_updateType = update;
 	}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 0b20876..c6cfe9f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -46,17 +46,8 @@ import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DnnOp;
 import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.IndexingOp;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.NaryOp;
 import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.ParameterizedBuiltinOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.UnaryOp;
 import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem;
 import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
@@ -64,6 +55,7 @@ import org.apache.sysds.hops.codegen.template.TemplateRow;
 import org.apache.sysds.hops.codegen.template.TemplateUtils;
 import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
 import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysds.hops.cost.ComputeCost;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
 import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
@@ -175,8 +167,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			//obtain hop compute costs per cell once
 			HashMap<Long, Double> computeCosts = new HashMap<>();
 			for( Long hopID : part.getPartition() )
-				getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
-			
+				computeCosts.put(hopID, ComputeCost.getHOPComputeCost(memo.getHopRefs().get(hopID)));
+
 			//prepare pruning helpers and prune memo table w/ determined mat points
 			StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts),
 				getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
@@ -1011,174 +1003,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		return costs;
 	}
 	
-	private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) 
-	{
-		//get costs for given hop
-		double costs = 1;
-		if( current instanceof UnaryOp ) {
-			switch( ((UnaryOp)current).getOp() ) {
-				case ABS:
-				case ROUND:
-				case CEIL:
-				case FLOOR:
-				case SIGN:    costs = 1; break; 
-				case SPROP:
-				case SQRT:    costs = 2; break;
-				case EXP:     costs = 18; break;
-				case SIGMOID: costs = 21; break;
-				case LOG:
-				case LOG_NZ:  costs = 32; break;
-				case NCOL:
-				case NROW:
-				case PRINT:
-				case ASSERT:
-				case CAST_AS_BOOLEAN:
-				case CAST_AS_DOUBLE:
-				case CAST_AS_INT:
-				case CAST_AS_MATRIX:
-				case CAST_AS_SCALAR: costs = 1; break;
-				case SIN:     costs = 18; break;
-				case COS:     costs = 22; break;
-				case TAN:     costs = 42; break;
-				case ASIN:    costs = 93; break;
-				case ACOS:    costs = 103; break;
-				case ATAN:    costs = 40; break;
-				case SINH:    costs = 93; break; // TODO:
-				case COSH:    costs = 103; break;
-				case TANH:    costs = 40; break;
-				case CUMSUM:
-				case CUMMIN:
-				case CUMMAX:
-				case CUMPROD: costs = 1; break;
-				case CUMSUMPROD: costs = 2; break;
-				default:
-					LOG.warn("Cost model not "
-						+ "implemented yet for: "+((UnaryOp)current).getOp());
-			}
-		}
-		else if( current instanceof BinaryOp ) {
-			switch( ((BinaryOp)current).getOp() ) {
-				case MULT: 
-				case PLUS:
-				case MINUS:
-				case MIN:
-				case MAX: 
-				case AND:
-				case OR:
-				case EQUAL:
-				case NOTEQUAL:
-				case LESS:
-				case LESSEQUAL:
-				case GREATER:
-				case GREATEREQUAL: 
-				case CBIND:
-				case RBIND:   costs = 1; break;
-				case INTDIV:  costs = 6; break;
-				case MODULUS: costs = 8; break;
-				case DIV:     costs = 22; break;
-				case LOG:
-				case LOG_NZ:  costs = 32; break;
-				case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
-						current.getInput().get(1), 2) ? 1 : 16); break;
-				case MINUS_NZ:
-				case MINUS1_MULT: costs = 2; break;
-				case MOMENT:
-					int type = (int) (current.getInput().get(1) instanceof LiteralOp ? 
-						HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
-					switch( type ) {
-						case 0: costs = 1; break; //count
-						case 1: costs = 8; break; //mean
-						case 2: costs = 16; break; //cm2
-						case 3: costs = 31; break; //cm3
-						case 4: costs = 51; break; //cm4
-						case 5: costs = 16; break; //variance
-					}
-					break;
-				case COV: costs = 23; break;
-				default:
-					LOG.warn("Cost model not "
-						+ "implemented yet for: "+((BinaryOp)current).getOp());
-			}
-		}
-		else if( current instanceof TernaryOp ) {
-			switch( ((TernaryOp)current).getOp() ) {
-				case IFELSE:
-				case PLUS_MULT: 
-				case MINUS_MULT: costs = 2; break;
-				case CTABLE:     costs = 3; break;
-				case MOMENT:
-					int type = (int) (current.getInput().get(1) instanceof LiteralOp ? 
-						HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
-					switch( type ) {
-						case 0: costs = 2; break; //count
-						case 1: costs = 9; break; //mean
-						case 2: costs = 17; break; //cm2
-						case 3: costs = 32; break; //cm3
-						case 4: costs = 52; break; //cm4
-						case 5: costs = 17; break; //variance
-					}
-					break;
-				case COV: costs = 23; break;
-				default:
-					LOG.warn("Cost model not "
-						+ "implemented yet for: "+((TernaryOp)current).getOp());
-			}
-		}
-		else if( current instanceof NaryOp ) {
-			costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ?
-				current.getInput().size() : 1;
-		}
-		else if( current instanceof ParameterizedBuiltinOp ) {
-			costs = 1;
-		}
-		else if( current instanceof IndexingOp ) {
-			costs = 1;
-		}
-		else if( current instanceof ReorgOp ) {
-			costs = 1;
-		}
-		else if( current instanceof DnnOp ) {
-			switch( ((DnnOp)current).getOp() ) {
-				case BIASADD:
-				case BIASMULT:
-					costs = 2;
-				default:
-					LOG.warn("Cost model not "
-						+ "implemented yet for: "+((DnnOp)current).getOp());
-			}
-		}
-		else if( current instanceof AggBinaryOp ) {
-			//outer product template w/ matrix-matrix 
-			//or row template w/ matrix-vector or matrix-matrix
-			costs = 2 * current.getInput().get(0).getDim2();
-			if( current.getInput().get(0).dimsKnown(true) )
-				costs *= current.getInput().get(0).getSparsity();
-		}
-		else if( current instanceof AggUnaryOp) {
-			switch(((AggUnaryOp)current).getOp()) {
-				case SUM:    costs = 4; break; 
-				case SUM_SQ: costs = 5; break;
-				case MIN:
-				case MAX:    costs = 1; break;
-				default:
-					LOG.warn("Cost model not "
-						+ "implemented yet for: "+((AggUnaryOp)current).getOp());
-			}
-			switch(((AggUnaryOp)current).getDirection()) {
-				case Col: costs *= Math.max(current.getInput().get(0).getDim1(),1); break;
-				case Row: costs *= Math.max(current.getInput().get(0).getDim2(),1); break;
-				case RowCol: costs *= getSize(current.getInput().get(0)); break;
-			}
-		}
-		
-		//scale by current output size in order to correctly reflect
-		//a mix of row and cell operations in the same fused operator
-		//(e.g., row template with fused column vector operations)
-		costs *= getSize(current);
-		
-		computeCosts.put(current.getHopID(), costs);
-	}
-	
 	private static boolean hasNoRefToMatPoint(long hopID, 
 			MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
 		return !InterestingPoint.isMatPoint(M, hopID, me, plan);
diff --git a/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
new file mode 100644
index 0000000..3ac64b6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DnnOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+
+/**
+ * Class with methods estimating compute costs of operations.
+ */
+public class ComputeCost {
+	private static final Log LOG = LogFactory.getLog(ComputeCost.class.getName());
+
+	/**
+	 * Get compute cost for given HOP based on the number of floating point operations per output cell
+	 * and the total number of output cells.
+	 * @param currentHop for which compute cost is returned
+	 * @return compute cost of currentHop as number of floating point operations
+	 */
+	public static double getHOPComputeCost(Hop currentHop){
+		double costs = 1;
+		if( currentHop instanceof UnaryOp) {
+			switch( ((UnaryOp)currentHop).getOp() ) {
+				case ABS:
+				case ROUND:
+				case CEIL:
+				case FLOOR:
+				case SIGN:    costs = 1; break;
+				case SPROP:
+				case SQRT:    costs = 2; break;
+				case EXP:     costs = 18; break;
+				case SIGMOID: costs = 21; break;
+				case LOG:
+				case LOG_NZ:  costs = 32; break;
+				case NCOL:
+				case NROW:
+				case PRINT:
+				case ASSERT:
+				case CAST_AS_BOOLEAN:
+				case CAST_AS_DOUBLE:
+				case CAST_AS_INT:
+				case CAST_AS_MATRIX:
+				case CAST_AS_SCALAR: costs = 1; break;
+				case SIN:     costs = 18; break;
+				case COS:     costs = 22; break;
+				case TAN:     costs = 42; break;
+				case ASIN:    costs = 93; break;
+				case ACOS:    costs = 103; break;
+				case ATAN:    costs = 40; break;
+				case SINH:    costs = 93; break; // TODO:
+				case COSH:    costs = 103; break;
+				case TANH:    costs = 40; break;
+				case CUMSUM:
+				case CUMMIN:
+				case CUMMAX:
+				case CUMPROD: costs = 1; break;
+				case CUMSUMPROD: costs = 2; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((UnaryOp)currentHop).getOp());
+			}
+		}
+		else if( currentHop instanceof BinaryOp) {
+			switch( ((BinaryOp)currentHop).getOp() ) {
+				case MULT:
+				case PLUS:
+				case MINUS:
+				case MIN:
+				case MAX:
+				case AND:
+				case OR:
+				case EQUAL:
+				case NOTEQUAL:
+				case LESS:
+				case LESSEQUAL:
+				case GREATER:
+				case GREATEREQUAL:
+				case CBIND:
+				case RBIND:   costs = 1; break;
+				case INTDIV:  costs = 6; break;
+				case MODULUS: costs = 8; break;
+				case DIV:     costs = 22; break;
+				case LOG:
+				case LOG_NZ:  costs = 32; break;
+				case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
+					currentHop.getInput().get(1), 2) ? 1 : 16); break;
+				case MINUS_NZ:
+				case MINUS1_MULT: costs = 2; break;
+				case MOMENT:
+					int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+						HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+					switch( type ) {
+						case 0: costs = 1; break; //count
+						case 1: costs = 8; break; //mean
+						case 2: costs = 16; break; //cm2
+						case 3: costs = 31; break; //cm3
+						case 4: costs = 51; break; //cm4
+						case 5: costs = 16; break; //variance
+					}
+					break;
+				case COV: costs = 23; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((BinaryOp)currentHop).getOp());
+			}
+		}
+		else if( currentHop instanceof TernaryOp) {
+			switch( ((TernaryOp)currentHop).getOp() ) {
+				case IFELSE:
+				case PLUS_MULT:
+				case MINUS_MULT: costs = 2; break;
+				case CTABLE:     costs = 3; break;
+				case MOMENT:
+					int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+						HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+					switch( type ) {
+						case 0: costs = 2; break; //count
+						case 1: costs = 9; break; //mean
+						case 2: costs = 17; break; //cm2
+						case 3: costs = 32; break; //cm3
+						case 4: costs = 52; break; //cm4
+						case 5: costs = 17; break; //variance
+					}
+					break;
+				case COV: costs = 23; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((TernaryOp)currentHop).getOp());
+			}
+		}
+		else if( currentHop instanceof NaryOp) {
+			costs = HopRewriteUtils.isNary(currentHop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) ?
+				currentHop.getInput().size() : 1;
+		}
+		else if( currentHop instanceof ParameterizedBuiltinOp) {
+			costs = 1;
+		}
+		else if( currentHop instanceof IndexingOp) {
+			costs = 1;
+		}
+		else if( currentHop instanceof ReorgOp) {
+			costs = 1;
+		}
+		else if( currentHop instanceof DnnOp) {
+			switch( ((DnnOp)currentHop).getOp() ) {
+				case BIASADD:
+				case BIASMULT:
+					costs = 2;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((DnnOp)currentHop).getOp());
+			}
+		}
+		else if( currentHop instanceof AggBinaryOp) {
+			//outer product template w/ matrix-matrix
+			//or row template w/ matrix-vector or matrix-matrix
+			costs = 2 * currentHop.getInput().get(0).getDim2();
+			if( currentHop.getInput().get(0).dimsKnown(true) )
+				costs *= currentHop.getInput().get(0).getSparsity();
+		}
+		else if( currentHop instanceof AggUnaryOp) {
+			switch(((AggUnaryOp)currentHop).getOp()) {
+				case SUM:    costs = 4; break;
+				case SUM_SQ: costs = 5; break;
+				case MIN:
+				case MAX:    costs = 1; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((AggUnaryOp)currentHop).getOp());
+			}
+			switch(((AggUnaryOp)currentHop).getDirection()) {
+				case Col: costs *= Math.max(currentHop.getInput().get(0).getDim1(),1); break;
+				case Row: costs *= Math.max(currentHop.getInput().get(0).getDim2(),1); break;
+				case RowCol: costs *= getSize(currentHop.getInput().get(0)); break;
+			}
+		}
+
+		//scale by current output size in order to correctly reflect
+		//a mix of row and cell operations in the same fused operator
+		//(e.g., row template with fused column vector operations)
+		costs *= getSize(currentHop);
+		return costs;
+	}
+
+	/**
+	 * Get number of output cells of given hop.
+	 * @param hop for which the number of output cells are found
+	 * @return number of output cells of given hop
+	 */
+	private static long getSize(Hop hop) {
+		return Math.max(hop.getDim1(),1)
+			* Math.max(hop.getDim2(),1);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
index f8d6a2d..23fcf51 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
@@ -32,7 +32,6 @@ import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 
 public class CostEstimationWrapper 
 {
-	
 	public enum CostType { 
 		NUM_MRJOBS, //based on number of MR jobs, [number MR jobs]
 		STATIC // based on FLOPS, read/write, etc, [time in sec]
@@ -44,17 +43,13 @@ public class CostEstimationWrapper
 	private static CostEstimator _costEstim = null;
 	
 	
-	static 
-	{
-
+	static  {
 		//create cost estimator
-		try
-		{
+		try {
 			//TODO config parameter?
 			_costEstim = createCostEstimator(DEFAULT_COSTTYPE);
 		}
-		catch(Exception ex)
-		{
+		catch(Exception ex) {
 			LOG.error("Failed cost estimator initialization.", ex);
 		}
 	}
@@ -89,5 +84,5 @@ public class CostEstimationWrapper
 			default:
 				throw new DMLRuntimeException("Unknown cost type: "+type);
 		}
-	}	
+	}
 }
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index bb8753c..03948d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -58,9 +58,8 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 
-public abstract class CostEstimator 
+public abstract class CostEstimator
 {
-	
 	protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());
 	
 	private static final int DEFAULT_NUMITER = 15;
@@ -84,7 +83,7 @@ public abstract class CostEstimator
 	public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String,VarStats> stats, boolean recursive) {
 		//obtain stats from symboltable (e.g., during recompile)
 		maintainVariableStatistics(vars, stats);
-				
+		
 		//get cost estimate
 		return rGetTimeEstimate(pb, stats, new HashSet<String>(), recursive);
 	}
@@ -281,10 +280,8 @@ public abstract class CostEstimator
 		VarStats[] vs = new VarStats[3];
 		String[] attr = null; 
 
-		if( inst instanceof UnaryCPInstruction )
-		{
-			if( inst instanceof DataGenCPInstruction )
-			{
+		if( inst instanceof UnaryCPInstruction ) {
+			if( inst instanceof DataGenCPInstruction ) {
 				DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
 				vs[0] = _unknownStats;
 				vs[1] = _unknownStats;
@@ -298,15 +295,13 @@ public abstract class CostEstimator
 					type = 1;
 				attr = new String[]{String.valueOf(type)};
 			}
-			else if( inst instanceof StringInitCPInstruction )
-			{
+			else if( inst instanceof StringInitCPInstruction ) {
 				StringInitCPInstruction rinst = (StringInitCPInstruction) inst;
 				vs[0] = _unknownStats;
 				vs[1] = _unknownStats;
 				vs[2] = stats.get( rinst.output.getName() );
 			}
-			else //general unary
-			{
+			else { //general unary
 				UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
 				vs[0] = stats.get( uinst.input1.getName() );
 				vs[1] = _unknownStats;
@@ -317,69 +312,61 @@ public abstract class CostEstimator
 				if( vs[2] == null ) //scalar output
 					vs[2] = _scalarStats;
 				
-				if( inst instanceof MMTSJCPInstruction )
-				{
+				if( inst instanceof MMTSJCPInstruction ) {
 					String type = ((MMTSJCPInstruction)inst).getMMTSJType().toString();
 					attr = new String[]{type};
 				} 
-				else if( inst instanceof AggregateUnaryCPInstruction )
-				{
+				else if( inst instanceof AggregateUnaryCPInstruction ) {
 					String[] parts = InstructionUtils.getInstructionParts(inst.toString());
 					String opcode = parts[0];
 					if( opcode.equals("cm") )
-						attr = new String[]{parts[parts.length-2]};						
-				} 
+						attr = new String[]{parts[parts.length-2]};
+				}
 			}
 		}
-		else if( inst instanceof BinaryCPInstruction )
-		{
+		else if( inst instanceof BinaryCPInstruction ) {
 			BinaryCPInstruction binst = (BinaryCPInstruction) inst;
 			vs[0] = stats.get( binst.input1.getName() );
 			vs[1] = stats.get( binst.input2.getName() );
 			vs[2] = stats.get( binst.output.getName() );
 			
-			
-			if( vs[0] == null ) //scalar input, 
+			if( vs[0] == null ) //scalar input,
 				vs[0] = _scalarStats;
-			if( vs[1] == null ) //scalar input, 
+			if( vs[1] == null ) //scalar input,
 				vs[1] = _scalarStats;
 			if( vs[2] == null ) //scalar output
 				vs[2] = _scalarStats;
-		}	
-		else if( inst instanceof AggregateTernaryCPInstruction )
-		{
+		}
+		else if( inst instanceof AggregateTernaryCPInstruction ) {
 			AggregateTernaryCPInstruction binst = (AggregateTernaryCPInstruction) inst;
 			//of same dimension anyway but missing third input
-			vs[0] = stats.get( binst.input1.getName() ); 
+			vs[0] = stats.get( binst.input1.getName() );
 			vs[1] = stats.get( binst.input2.getName() );
 			vs[2] = stats.get( binst.output.getName() );
 				
-			if( vs[0] == null ) //scalar input, 
+			if( vs[0] == null ) //scalar input,
 				vs[0] = _scalarStats;
-			if( vs[1] == null ) //scalar input, 
+			if( vs[1] == null ) //scalar input,
 				vs[1] = _scalarStats;
 			if( vs[2] == null ) //scalar output
 				vs[2] = _scalarStats;
 		}
-		else if( inst instanceof ParameterizedBuiltinCPInstruction )
-		{
+		else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
 			//ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
 			String[] parts = InstructionUtils.getInstructionParts(inst.toString());
 			String opcode = parts[0];
-			if( opcode.equals("groupedagg") )
-			{				
+			if( opcode.equals("groupedagg") ) {
 				HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
 				String fn = paramsMap.get("fn");
 				String order = paramsMap.get("order");
 				AggregateOperationTypes type = CMOperator.getAggOpType(fn, order);
 				attr = new String[]{String.valueOf(type.ordinal())};
 			}
-			else if( opcode.equals("rmempty") )
-			{
+			else if( opcode.equals("rmempty") ) {
 				HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
 				attr = new String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)};
 			}
-				
+			
 			vs[0] = stats.get( parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
 			vs[1] = _unknownStats; //TODO
 			vs[2] = stats.get( parts[parts.length-1] );
@@ -389,16 +376,14 @@ public abstract class CostEstimator
 			if( vs[2] == null ) //scalar output
 				vs[2] = _scalarStats;
 		}
-		else if( inst instanceof MultiReturnBuiltinCPInstruction )
-		{
+		else if( inst instanceof MultiReturnBuiltinCPInstruction ) {
 			//applies to qr, lu, eigen (cost computation on input1)
 			MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction) inst;
 			vs[0] = stats.get( minst.input1.getName() );
 			vs[1] = stats.get( minst.getOutput(0).getName() );
 			vs[2] = stats.get( minst.getOutput(1).getName() );
 		}
-		else if( inst instanceof VariableCPInstruction )
-		{
+		else if( inst instanceof VariableCPInstruction ) {
 			setUnknownStats(vs);
 			
 			VariableCPInstruction varinst = (VariableCPInstruction) inst;
@@ -407,11 +392,10 @@ public abstract class CostEstimator
 				if( stats.containsKey( varinst.getInput1().getName() ) )
 					vs[0] = stats.get( varinst.getInput1().getName() );	
 				attr = new String[]{varinst.getInput3().getName()};
-			}	
+			}
 		}
-		else
-		{
-			setUnknownStats(vs);		
+		else {
+			setUnknownStats(vs);
 		}
 		
 		//maintain var status (CP output always inmem)
@@ -426,7 +410,7 @@ public abstract class CostEstimator
 	private static void setUnknownStats(VarStats[] vs) {
 		vs[0] = _unknownStats;
 		vs[1] = _unknownStats;
-		vs[2] = _unknownStats;	
+		vs[2] = _unknownStats;
 	}
 		
 	private static long getNumIterations(HashMap<String,VarStats> stats, ForProgramBlock pb) {
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
index e2a4a75..3f29132 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
@@ -344,30 +344,30 @@ public class CostEstimatorStaticRuntime extends CostEstimator
 						}
 						return (leftSparse) ? xcm * (d1m * d1s + 1) : xcm * d1m;
 					}
-				    else if( optype.equals("uatrace") || optype.equals("uaktrace") )
-				    	return 2 * d1m * d1n;
-				    else if( optype.equals("ua+") || optype.equals("uar+") || optype.equals("uac+")  ){
-				    	//sparse safe operations
-				    	if( !leftSparse ) //dense
-				    		return d1m * d1n;
-				    	else //sparse
-				    		return d1m * d1n * d1s;
-				    }
-				    else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
-				    	return 4 * d1m * d1n; //1*k+
-				    else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
+					else if( optype.equals("uatrace") || optype.equals("uaktrace") )
+						return 2 * d1m * d1n;
+					else if( optype.equals("ua+") || optype.equals("uar+") || optype.equals("uac+")  ){
+						//sparse safe operations
+						if( !leftSparse ) //dense
+							return d1m * d1n;
+						else //sparse
+							return d1m * d1n * d1s;
+					}
+					else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
+						return 4 * d1m * d1n; //1*k+
+					else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
 						return 5 * d1m * d1n; // +1 for multiplication to square term
-				    else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
+					else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
 						return 7 * d1m * d1n; //1*k+
-				    else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
+					else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
 						return 14 * d1m * d1n;
-				    else if(   optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
-				    		|| optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
-				    		|| optype.equals("uarimax") || optype.equals("ua*") )
-				    	return d1m * d1n;
+					else if(   optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
+						|| optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
+						|| optype.equals("uarimax") || optype.equals("ua*") )
+						return d1m * d1n;
 					
-				    return 0;
-				    
+					return 0;
+				
 				case Binary: //opcodes: +, -, *, /, ^ (incl. ^2, *2),
 					//max, min, solve, ==, !=, <, >, <=, >=  
 					//note: all relational ops are not sparsesafe
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
new file mode 100644
index 0000000..f4f8db4
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+/**
+ * Execution cost estimate of a federated plan. The estimate is kept in separate categories
+ * (compute, read, input transfer, output transfer) plus the accumulated cost of all inputs,
+ * and {@link #getTotal()} sums them up.
+ */
+public class FederatedCost {
+	protected double _computeCost = 0;
+	protected double _readCost = 0;
+	protected double _inputTransferCost = 0;
+	protected double _outputTransferCost = 0;
+	protected double _inputTotalCost = 0;
+
+	/**
+	 * Create a cost estimate where every category is zero.
+	 */
+	public FederatedCost(){}
+
+	/**
+	 * Create a cost estimate from the given category values.
+	 * @param readCost cost of reading the inputs
+	 * @param inputTransferCost cost of transferring the inputs
+	 * @param outputTransferCost cost of transferring the output
+	 * @param computeCost cost of the computation itself
+	 * @param inputTotalCost accumulated total cost of the inputs
+	 */
+	public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
+		double computeCost, double inputTotalCost){
+		_computeCost = computeCost;
+		_readCost = readCost;
+		_inputTransferCost = inputTransferCost;
+		_outputTransferCost = outputTransferCost;
+		_inputTotalCost = inputTotalCost;
+	}
+
+	/**
+	 * Sum of all cost categories of this estimate.
+	 * @return total cost
+	 */
+	public double getTotal(){
+		double total = _computeCost;
+		total += _readCost;
+		total += _inputTransferCost;
+		total += _outputTransferCost;
+		total += _inputTotalCost;
+		return total;
+	}
+
+	/**
+	 * Scale the accumulated input cost by the number of repetitions; the other categories stay unchanged.
+	 * @param repetitionNumber number of repetitions of the costs
+	 */
+	public void addRepetitionCost(int repetitionNumber){
+		_inputTotalCost = _inputTotalCost * repetitionNumber;
+	}
+
+	/**
+	 * Accumulated cost of the inputs.
+	 * @return summed input costs
+	 */
+	public double getInputTotalCost(){
+		return _inputTotalCost;
+	}
+
+	/**
+	 * Overwrite the accumulated cost of the inputs.
+	 * @param inputTotalCost new summed input costs
+	 */
+	public void setInputTotalCost(double inputTotalCost){
+		_inputTotalCost = inputTotalCost;
+	}
+
+	/**
+	 * Increase the accumulated input cost by the given amount.
+	 * @param additionalCost to add to total input cost
+	 */
+	public void addInputTotalCost(double additionalCost){
+		_inputTotalCost = _inputTotalCost + additionalCost;
+	}
+
+	/**
+	 * Increase the accumulated input cost by the total of another estimate.
+	 * @param federatedCost estimate whose total is added
+	 */
+	public void addInputTotalCost(FederatedCost federatedCost){
+		addInputTotalCost(federatedCost.getTotal());
+	}
+
+	/**
+	 * Add every category of another estimate to the matching category of this estimate.
+	 * @param additionalCost estimate to add to this object
+	 */
+	public void addFederatedCost(FederatedCost additionalCost){
+		_computeCost += additionalCost._computeCost;
+		_readCost += additionalCost._readCost;
+		_inputTransferCost += additionalCost._inputTransferCost;
+		_outputTransferCost += additionalCost._outputTransferCost;
+		_inputTotalCost += additionalCost._inputTotalCost;
+	}
+
+	@Override
+	public String toString(){
+		return new StringBuilder()
+			.append(" computeCost: ").append(_computeCost)
+			.append("\n readCost: ").append(_readCost)
+			.append("\n inputTransferCost: ").append(_inputTransferCost)
+			.append("\n outputTransferCost: ").append(_outputTransferCost)
+			.append("\n inputTotalCost: ").append(_inputTotalCost)
+			.append("\n total cost: ").append(getTotal())
+			.toString();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
new file mode 100644
index 0000000..3e2f994
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+
+/**
+ * Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
+ * Memory estimates are divided by bandwidths in bytes per second and compute costs by FLOPS, so the estimates
+ * are time-like quantities rather than byte counts.
+ */
+public class FederatedCostEstimator {
+	public int DEFAULT_MEMORY_ESTIMATE = 8;
+	public int DEFAULT_ITERATION_NUMBER = 15;
+	public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
+	public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+	public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
+	public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
+
+	public boolean printCosts = false; //Temporary for debugging purposes
+
+	/**
+	 * Estimate the cost of the given DML program by summing the totals of its top-level statement blocks.
+	 * @param dmlProgram for which the cost is estimated
+	 * @return new federated cost object holding the summed estimate as input total cost
+	 */
+	public FederatedCost costEstimate(DMLProgram dmlProgram){
+		FederatedCost programTotalCost = new FederatedCost();
+		for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() )
+			programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
+		return programTotalCost;
+	}
+
+	/**
+	 * Cost estimate of given statement block. Control-flow blocks are estimated from their predicates
+	 * and bodies; plain statement blocks are estimated from their Hop DAG roots.
+	 * @param sb statement block
+	 * @return federated cost object with cost estimate
+	 */
+	private FederatedCost costEstimate(StatementBlock sb){
+		if ( sb instanceof WhileStatementBlock){
+			WhileStatementBlock whileSB = (WhileStatementBlock) sb;
+			//Copy the predicate cost into a new object: costEstimate(Hop) returns the object cached in the
+			//predicate hop, which would otherwise be modified by the body and repetition costs added below
+			FederatedCost whileSBCost = new FederatedCost();
+			whileSBCost.addFederatedCost(costEstimate(whileSB.getPredicateHops()));
+			for ( Statement statement : whileSB.getStatements() ){
+				WhileStatement whileStatement = (WhileStatement) statement;
+				for ( StatementBlock bodyBlock : whileStatement.getBody() )
+					whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
+			}
+			whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
+			return whileSBCost;
+		}
+		else if ( sb instanceof IfStatementBlock){
+			//Get cost of if-block + else-block and divide by two
+			// since only one of the code blocks will be executed in the end
+			IfStatementBlock ifSB = (IfStatementBlock) sb;
+			FederatedCost ifSBCost = new FederatedCost();
+			for ( Statement statement : ifSB.getStatements() ){
+				IfStatement ifStatement = (IfStatement) statement;
+				for ( StatementBlock ifBodySB : ifStatement.getIfBody() )
+					ifSBCost.addInputTotalCost(costEstimate(ifBodySB));
+				for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
+					ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
+			}
+			ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
+			ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
+			return ifSBCost;
+		}
+		else if ( sb instanceof ForStatementBlock){
+			// This also includes ParForStatementBlocks
+			ForStatementBlock forSB = (ForStatementBlock) sb;
+			ArrayList<Hop> predicateHops = new ArrayList<>();
+			predicateHops.add(forSB.getFromHops());
+			predicateHops.add(forSB.getToHops());
+			predicateHops.add(forSB.getIncrementHops()); //presumably null without explicit increment; nulls are skipped
+			FederatedCost forSBCost = costEstimate(predicateHops);
+			for ( Statement statement : forSB.getStatements() ){
+				ForStatement forStatement = (ForStatement) statement;
+				for ( StatementBlock forStatementBlockBody : forStatement.getBody() )
+					forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
+			}
+			forSBCost.addRepetitionCost(forSB.getEstimateReps());
+			return forSBCost;
+		}
+		else if ( sb instanceof FunctionStatementBlock){
+			FederatedCost funcCost = addInitialInputCost(sb);
+			FunctionStatementBlock funcSB = (FunctionStatementBlock) sb;
+			for(Statement statement : funcSB.getStatements()) {
+				FunctionStatement funcStatement = (FunctionStatement) statement;
+				for ( StatementBlock funcStatementBody : funcStatement.getBody() )
+					funcCost.addInputTotalCost(costEstimate(funcStatementBody));
+			}
+			return funcCost;
+		}
+		else {
+			// StatementBlock type (no subclass)
+			return costEstimate(sb.getHops());
+		}
+	}
+
+	/**
+	 * Creates new FederatedCost object and adds the cost estimates of all top-level statement blocks
+	 * of the DML program the given statement block belongs to.
+	 * NOTE(review): for function statement blocks this adds the cost of the whole program; confirm intended.
+	 * @param sb statement block
+	 * @return new FederatedCost object with the summed estimates as input total cost
+	 */
+	private FederatedCost addInitialInputCost(StatementBlock sb){
+		FederatedCost basicCost = new FederatedCost();
+		for ( StatementBlock childSB : sb.getDMLProg().getStatementBlocks() )
+			basicCost.addInputTotalCost(costEstimate(childSB).getTotal());
+		return basicCost;
+	}
+
+	/**
+	 * Cost estimate of given list of roots.
+	 * The individual cost estimates of the hops are summed; null entries are skipped.
+	 * @param roots list of hops, entries may be null
+	 * @return new FederatedCost object with sum of cost estimates of given hops
+	 */
+	private FederatedCost costEstimate(ArrayList<Hop> roots){
+		FederatedCost basicCost = new FederatedCost();
+		for ( Hop root : roots ) {
+			if ( root != null )
+				basicCost.addInputTotalCost(costEstimate(root));
+		}
+		return basicCost;
+	}
+
+	/**
+	 * Return cost estimate of Hop DAG starting from given root. The estimate is cached in the root,
+	 * so callers must not modify the returned object.
+	 * @param root of Hop DAG for which cost is estimated
+	 * @return cost estimation of Hop DAG starting from given root
+	 */
+	private FederatedCost costEstimate(Hop root){
+		if ( root.federatedCostInitialized() )
+			return root.getFederatedCost();
+		else {
+			// If no input has FOUT, the root will be processed by the coordinator
+			boolean hasFederatedInput = root.someInputFederated();
+			//the input cost is included the first time the input hop is used
+			//for additional usage, the additional cost is zero (disregarding potential read cost)
+			double inputCosts = root.getInput().stream()
+				.mapToDouble( in -> in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
+				.sum();
+			double inputTransferCost = hasFederatedInput ? root.getInput().stream()
+				.filter(Hop::hasLocalOutput)
+				.mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+				.map(inMem -> inMem/ WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+				.sum() : 0;
+			double computingCost = ComputeCost.getHOPComputeCost(root);
+			if ( hasFederatedInput ){
+				//Find the number of inputs that have FOUT set (at least one, so the division below is never by zero)
+				int numWorkers = Math.max(1, (int)root.getInput().stream().filter(Hop::hasFederatedOutput).count());
+				//divide the compute cost by the number of workers the computation would be split to multiplied by
+				//the number of parallel processes at each worker multiplied by the FLOPS of each process
+				//This assumes uniform workload among the workers with FOUT data involved in the operation
+				//and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+				computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+			} else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+			//Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+			double outputTransferCost = ( root.hasLocalOutput() && hasFederatedInput ) ?
+				root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+			double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+
+			FederatedCost rootFedCost =
+				new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+			root.setFederatedCost(rootFedCost);
+
+			if ( printCosts )
+				printCosts(root);
+
+			return rootFedCost;
+		}
+	}
+
+	/**
+	 * Prints costs and information about root for debugging purposes
+	 * @param root hop for which information is printed
+	 */
+	private static void printCosts(Hop root){
+		System.out.println("===============================");
+		System.out.println(root);
+		System.out.println("Is federated: " + root.isFederated());
+		System.out.println("Has federated output: " + root.hasFederatedOutput());
+		System.out.println(root.getText());
+		System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
+		System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
+		System.out.println(root.getFederatedCost().toString());
+		System.out.println("===============================");
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 2e3edb0..04cdf32 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,6 +140,7 @@ public class ProgramRewriter
 			}
 			if ( OptimizerUtils.FEDERATED_COMPILATION ) {
 				_dagRuleSet.add( new RewriteFederatedExecution() );
+				_sbRuleSet.add( new RewriteFederatedStatementBlocks() );
 			}
 		}
 		
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 29cda4a..e6a92ce 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,9 +23,14 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -36,6 +41,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -52,6 +58,9 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.EnumMap;
+import java.util.Map;
 import java.util.concurrent.Future;
 
 public class RewriteFederatedExecution extends HopRewriteRule {
@@ -61,18 +70,133 @@ public class RewriteFederatedExecution extends HopRewriteRule {
 			return null;
 		for ( Hop root : roots )
 			visitHop(root);
+
+		return selectFederatedExecutionPlan(roots);
+	}
+
+	/**
+	 * Rewrite a single Hop DAG by visiting it like {@link #rewriteHopDAGs(ArrayList, ProgramRewriteStatus)}
+	 * and selecting the federated execution plan for it; the DAG is updated in place.
+	 * @param root root of the Hop DAG, may be null
+	 * @param state program rewrite status
+	 * @return the given root, or null if the root is null
+	 */
+	@Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if ( root == null )
+			return null;
+		visitHop(root);
+		ArrayList<Hop> roots = new ArrayList<>(Collections.singletonList(root));
+		selectFederatedExecutionPlan(roots);
+		return root;
+	}
+
+	/**
+	 * Select federated execution plan for every Hop in the DAG starting from given roots.
+	 * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields.
+	 * @return the list of roots with updated FederatedOutput fields.
+	 */
+	private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
+		for (Hop root : roots){
+			root.resetVisitStatus();
+		}
+		for ( Hop root : roots ){
+			visitFedPlanHop(root);
+		}
 		return roots;
 	}
 
-	@Override
-	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-		if( root == null )
-			return null;
-		visitHop(root);
-		return root;
+	/**
+	 * Go through the Hop DAG and set the FederatedOutput field for each Hop from leaf to given currentHop.
+	 * @param currentHop the Hop from which the DAG is visited
+	 */
+	private static void visitFedPlanHop(Hop currentHop){
+		if ( currentHop.isVisited() )
+			return;
+		if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
+			// Depth first to get to the input
+			for ( Hop input : currentHop.getInput() )
+				visitFedPlanHop(input);
+		} else if ( isFederatedDataOp(currentHop) ) {
+			// leaf federated node
+			//TODO: This will block the cases where the federated DataOp is based on input that are also federated.
+			// This means that the actual federated leaf nodes will never be reached.
+			currentHop.setFederatedOutput(FederatedOutput.FOUT);
+		}
+		if ( isFedInstSupportedHop(currentHop) ){
+			// The Hop can be FOUT or LOUT or None. Check utility of FOUT vs LOUT vs None.
+			currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
+		}
+		else
+			currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
+		currentHop.setVisited();
+	}
+
+	/**
+	 * Returns the FederatedOutput with the highest utility out of the valid FederatedOutput values.
+	 * @param hop for which the utility is found
+	 * @return the FederatedOutput value with highest utility for the given Hop
+	 */
+	private static FederatedOutput getHighestUtilFedOut(Hop hop){
+		Map<FederatedOutput,Long> fedOutUtilMap = new EnumMap<>(FederatedOutput.class);
+		if ( isFOUTSupported(hop) )
+			fedOutUtilMap.put(FederatedOutput.FOUT, getUtilFout());
+		//LOUT is only an option if the hop has no privacy constraints
+		if ( hop.getPrivacy() == null || !hop.getPrivacy().hasConstraints() )
+			fedOutUtilMap.put(FederatedOutput.LOUT, getUtilLout(hop));
+		fedOutUtilMap.put(FederatedOutput.NONE, 0L);
+
+		Map.Entry<FederatedOutput, Long> fedOutMax = Collections.max(fedOutUtilMap.entrySet(), Map.Entry.comparingByValue());
+		return fedOutMax.getKey();
 	}
-	
-	private void visitHop(Hop hop){
+
+	/**
+	 * Utility if hop is FOUT. This is a simple version where it always returns 1.
+	 * @return utility if hop is FOUT
+	 */
+	private static long getUtilFout(){
+		//TODO: Make better utility estimation
+		return 1;
+	}
+
+	/**
+	 * Utility if hop is LOUT. This is a simple version only based on dimensions.
+	 * @param hop for which utility is calculated
+	 * @return utility if hop is LOUT
+	 */
+	private static long getUtilLout(Hop hop){
+		//TODO: Make better utility estimation
+		return -(long)hop.getMemEstimate();
+	}
+
+	/**
+	 * Checks whether the hop can run as a federated instruction: it has to be a federated DataOp or
+	 * have at least one FOUT input, and it has to be one of the supported hop types.
+	 * @param hop for which support is checked
+	 * @return true if a federated instruction is supported for the hop
+	 */
+	private static boolean isFedInstSupportedHop(Hop hop){
+		// Check that some input is FOUT, otherwise none of the fed instructions will run unless it is fedinit
+		if ( !isFederatedDataOp(hop) && hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
+			return false;
+
+		// The following operations are supported given that the above conditions have not returned already
+		return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+			|| hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp );
+	}
+
+	/**
+	 * Checks to see if the associatedHop supports FOUT.
+	 * @param associatedHop for which FOUT support is checked
+	 * @return true if FOUT is supported by the associatedHop
+	 */
+	private static boolean isFOUTSupported(Hop associatedHop){
+		// If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
+		if ( associatedHop instanceof AggUnaryOp )
+			return !associatedHop.isScalar();
+		return true;
+	}
+
+	private static void visitHop(Hop hop){
 		if (hop.isVisited())
 			return;
 
@@ -84,15 +208,6 @@ public class RewriteFederatedExecution extends HopRewriteRule {
 		hop.setVisited();
 	}
 
-	private static void privacyBasedHopDecision(Hop hop){
-		PrivacyPropagator.hopPropagation(hop);
-		PrivacyConstraint privacyConstraint = hop.getPrivacy();
-		if ( privacyConstraint != null && privacyConstraint.hasConstraints() )
-			hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
-		else if ( hop.someInputFederated() )
-			hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
-	}
-
 	/**
 	 * Get privacy constraints of DataOps from federated worker,
 	 * propagate privacy constraints from input to current hop,
@@ -101,7 +216,7 @@ public class RewriteFederatedExecution extends HopRewriteRule {
 	 */
 	private static void privacyBasedHopDecisionWithFedCall(Hop hop){
 		loadFederatedPrivacyConstraints(hop);
-		privacyBasedHopDecision(hop);
+		PrivacyPropagator.hopPropagation(hop);
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
new file mode 100644
index 0000000..18b36d5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Statement block rewrite registered for federated compilation. It currently leaves every
+ * statement block unchanged and never splits DAGs.
+ */
+public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
+
+	/**
+	 * Reports whether this rewrite may split DAGs (used for phase ordering of rewrites).
+	 * @return always false
+	 */
+	@Override
+	public boolean createsSplitDag() {
+		return false;
+	}
+
+	/**
+	 * Rewrite a single statement block. No rewrite is applied, so the block is returned
+	 * wrapped in a list.
+	 * @param sb    statement block
+	 * @param state program rewrite status
+	 * @return list containing only the given statement block
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		StatementBlock[] unchanged = {sb};
+		return Arrays.asList(unchanged);
+	}
+
+	/**
+	 * Rewrite a sequence of statement blocks. No rewrite is applied, so the given list
+	 * itself is returned.
+	 * @param sbs   list of statement blocks
+	 * @param state program rewrite status
+	 * @return the given list of statement blocks
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		return sbs;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index bae38a2..755287a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -64,6 +64,7 @@ public class FEDInstructionParser extends InstructionParser
 		String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
 		String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
 		String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+		String2FEDInstructionType.put( "rev"    , FEDType.Reorg );
 
 		// Ternary Instruction Opcodes
 		String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index fb0647e..2e5366e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -124,9 +125,60 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
 		map.execute(getTID(), fr1);
 
-		// derive new fed mapping for output
 		MatrixObject out = ec.getMatrixObject(output);
-		out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+		deriveNewOutputFedMapping(in, out, fr1);
+	}
+
+	/**
+	 * Sets the output fed mapping based on the input's federated partitioning and the aggregation type.
+	 * @param in matrix object from which the federated partitioning originates
+	 * @param out matrix object holding the dimensions of the instruction output
+	 * @param fr1 federated request holding the instruction execution call; its ID becomes the output mapping ID
+	 */
+	private void deriveNewOutputFedMapping(MatrixObject in, MatrixObject out, FederatedRequest fr1){
+		//Only uack+ (column aggregation) and uark+ (row aggregation) are supported; other opcodes throw
+		if ( !(instOpcode.equals("uack+") || instOpcode.equals("uark+")) )
+			throw new DMLRuntimeException("Operation " + instOpcode + " is unknown to FOUT processing");
+		boolean isColAgg = instOpcode.equals("uack+");
+		//Partition type of the federated input
+		FederationMap.FType inFtype = in.getFedMapping().getType();
+		//Copy of the input fed mapping, carrying the ID of fr1's output
+		FederationMap inputFedMapCopy = in.getFedMapping().copyWithNewID(fr1.getID());
+		//NOTE(review): the three ifs below are independent; if an FType reports both row and col partitioning, later branches override earlier ones - confirm intended
+		//if partition type is row and aggregation type is row
+		//   then keep the row split of each input range and set its column end to the output column count
+		//   and set FType to ROW
+		if ( inFtype.isRowPartitioned() && !isColAgg ){
+			for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+				range.setEndDim(1,out.getNumColumns());
+			inputFedMapCopy.setType(FederationMap.FType.ROW);
+		}
+		//if partition type is row and aggregation type is col
+		//   then make every range span the full output (begin 0/0, end = output rows/cols), i.e., partial aggregates
+		//   and set FType to PART
+		//if partition type is col and aggregation type is row
+		//   then likewise make every range span the full output dimensions
+		//   and set FType to PART
+		if ( (inFtype.isRowPartitioned() && isColAgg) || (inFtype.isColPartitioned() && !isColAgg) ){
+			for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() ){
+				range.setBeginDim(0,0);
+				range.setBeginDim(1,0);
+				range.setEndDim(0,out.getNumRows());
+				range.setEndDim(1,out.getNumColumns());
+			}
+			inputFedMapCopy.setType(FederationMap.FType.PART);
+		}
+		//if partition type is col and aggregation type is col
+		//   then set each range's row end to the output row count and keep the input column split
+		//   and set FType to COL
+		if ( inFtype.isColPartitioned() && isColAgg ){
+			for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+				range.setEndDim(0,out.getNumRows());
+			inputFedMapCopy.setType(FederationMap.FType.COL);
+		}
+
+		//attach the derived mapping to the output
+		out.setFedMapping(inputFedMapCopy);
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 825b984..c2e7ab1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -130,8 +130,9 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
 		}
 		else {
 			throw new DMLRuntimeException("Unsupported federated append: "
-				+ (mo1.isFederated() ? mo1.getFedMapping().getType().name():"LOCAL") + " "
-				+ (mo2.isFederated() ? mo2.getFedMapping().getType().name():"LOCAL") + " " + _cbind);
+				+ "input 1 FType is " + (mo1.isFederated() ? mo1.getFedMapping().getType().name():"LOCAL")
+				+ ", input 2 FType is " + (mo2.isFederated() ? mo2.getFedMapping().getType().name():"LOCAL")
+				+ ", and column bind is " + _cbind);
 		}
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2a1cbb1..8795308 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -217,10 +217,10 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 	 * @param mo2 input matrix object mo2
 	 * @return boolean indicating if the output can be kept on the federated sites
 	 */
-	private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+	private static boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
 		MatrixBlock mb = mo2.acquireReadAndRelease();
 		FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
-		SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+		SortedMap<Double, Double> fedDims = new TreeMap<>(); // <beginDim, endDim>
 
 		// collect min and max of the corresponding slices of mo2
 		IntStream.range(0, fedRanges.length).forEach(i -> {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f35030f..0e00faa 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -60,6 +60,9 @@ public abstract class FEDInstruction extends Instruction {
 		public boolean isForcedLocal() {
 			return this == LOUT;
 		}
+		public boolean isForced(){ // true only for FOUT (forced federated) or LOUT (forced local)
+			return this == FOUT || this == LOUT;
+		}
 	}
 
 	protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 19847a2..cb3074f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Future;
 import java.util.stream.Stream;
 
 import org.apache.commons.lang3.tuple.Pair;
@@ -52,7 +53,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
-	
+
 	public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
 		super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
 	}
@@ -96,24 +97,36 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 		if( !mo1.isFederated() )
 			throw new DMLRuntimeException("Federated Reorg: "
 				+ "Federated input expected, but invoked w/ "+mo1.isFederated());
+		if ( !( mo1.isFederated(FederationMap.FType.COL) || mo1.isFederated(FederationMap.FType.ROW)) )
+			throw new DMLRuntimeException("Federation type " + mo1.getFedMapping().getType()
+				+ " is not supported for Reorg processing");
 
 		if(instOpcode.equals("r'")) {
 			//execute transpose at federated site
 			FederatedRequest fr1 = FederationUtils.callInstruction(instString,
 				output, new CPOperand[] {input1},
 				new long[] {mo1.getFedMapping().getID()}, true);
-			mo1.getFedMapping().execute(getTID(), true, fr1);
+			if (_fedOut != null && !_fedOut.isForcedLocal()){
+				mo1.getFedMapping().execute(getTID(), true, fr1);
 
-			//drive output federated mapping
-			MatrixObject out = ec.getMatrixObject(output);
-			out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
-			out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+				//drive output federated mapping
+				MatrixObject out = ec.getMatrixObject(output);
+				out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
+				out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+			} else {
+				FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+				Future<FederatedResponse>[] execResponse = mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
+				ec.setMatrixOutput(output.getName(),
+					FederationUtils.bind(execResponse, mo1.isFederated(FederationMap.FType.COL)));
+			}
 		}
 		else if(instOpcode.equalsIgnoreCase("rev")) {
+			//check for a trailing fedout operand in this instruction's own string (no shared static state)
+			boolean fedoutFlagInString = InstructionUtils.getInstructionPartsWithValueType(instString).length > 3;
 			//execute transpose at federated site
 			FederatedRequest fr1 = FederationUtils.callInstruction(instString,
 				output, new CPOperand[] {input1},
-				new long[] {mo1.getFedMapping().getID()}, true);
+				new long[] {mo1.getFedMapping().getID()}, fedoutFlagInString);
 			mo1.getFedMapping().execute(getTID(), true, fr1);
 
 			if(mo1.isFederated(FederationMap.FType.ROW))
@@ -123,6 +136,11 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 			MatrixObject out = ec.getMatrixObject(output);
 			out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
 			out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+
+			if ( _fedOut != null && _fedOut.isForcedLocal() ){
+				out.acquireReadAndRelease();
+				out.getFedMapping().cleanup(getTID(), fr1.getID());
+			}
 		}
 		else if (instOpcode.equals("rdiag")) {
 			RdiagResult result;
@@ -158,6 +176,10 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 				.set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1),
 					(int) mo1.getBlocksize());
 			rdiag.setFedMapping(diagFedMap);
+			if ( _fedOut != null && _fedOut.isForcedLocal() ){
+				rdiag.acquireReadAndRelease();
+				rdiag.getFedMapping().cleanup(getTID(), rdiag.getFedMapping().getID());
+			}
 		}
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3d96cd9..3ac462c 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import org.apache.commons.io.FileUtils;
@@ -2097,6 +2098,31 @@ public abstract class AutomatedTestBase {
 		return false;
 	}
 
+	/**
+	 * Checks whether all given opcodes are in the set of CP heavy hitter opcodes.
+	 * @param str opcodes to look up in {@link Statistics#getCPHeavyHitterOpCodes()}
+	 * @return true if every given opcode is a heavy hitter (also true when no opcode is given)
+	 */
+	protected boolean heavyHittersContainsAllString(String... str){
+		Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+		return Arrays.stream(str).allMatch(heavyHitters::contains);
+	}
+
+	/**
+	 * Returns the given opcodes that are not present in the set of CP heavy hitter opcodes, in argument order.
+	 * @param opcodes opcodes to look up among the CP heavy hitters
+	 * @return opcodes not found among the heavy hitters (an empty array if all are present)
+	 */
+	protected String[] missingHeavyHitters(String... opcodes){
+		Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+		List<String> missingHeavyHitters = new ArrayList<>();
+		for (String opcode : opcodes){
+			if ( !heavyHitters.contains(opcode) )
+				missingHeavyHitters.add(opcode);
+		}
+		return missingHeavyHitters.toArray(new String[0]);
+	}
+
 	protected boolean heavyHittersContainsString(String str, int minCount) {
 		int count = 0;
 		for(String opcode : Statistics.getCPHeavyHitterOpCodes())
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000..0092a3a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.cost.FederatedCost;
+import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.sysds.common.Types.OpOp2.MULT;
+
+public class FederatedCostEstimatorTest extends AutomatedTestBase {
+
+	private static final String TEST_DIR = "functions/privacy/fedplanning/";
+	private static final String HOME = SCRIPT_DIR + TEST_DIR;
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/";
+	FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator(); // NOTE(review): tests assign WORKER_* fields via this instance; if those are static they persist across tests - confirm/reset
+
+	@Override
+	public void setUp() {}
+
+	@Test
+	public void simpleBinary() {
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+
+		/*
+		 * HOP			Occurrences		ComputeCost		ReadCost	ComputeCostFinal	ReadCostFinal
+		 * ------------------------------------------------------------------------------------------
+		 * LiteralOp	16				1				0			0.0625				0
+		 * DataGenOp	2				100				64			6.25				6.4
+		 * BinaryOp		1				100				1600		6.25				160
+		 * TOSTRING		1				1				800			0.0625				80
+		 * UnaryOp		1				1				8			0.0625				0.8
+		 */
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+
+		double expectedCost = computeCost + readCost;
+		runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void ifElseTest(){
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+		runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void whileTest(){
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+		runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void forLoopTest(){
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+		double expectedCost = (computeCost + readCost + predicateCost) * 5;
+		runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void parForLoopTest(){
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+		double expectedCost = (computeCost + readCost + predicateCost) * 5;
+		runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void functionTest(){
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double expectedCost = (computeCost + readCost);
+		runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+	}
+
+	@Test
+	public void federatedMultiply() {
+		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+
+		double literalOpCost = 10*0.0625;
+		double naryOpCostSpecial = (0.125+2.2);
+		double naryOpCostSpecial2 = (0.25+6.4);
+		double naryOpCost = 4*(0.125+1.6);
+		double reorgOpCost = 6250+80015.2+160030.4;
+		double binaryOpMultCost = 3125+160000;
+		double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+		double dataOpCost = 2*(6250+5.6);
+		double dataOpWriteCost = 6.25+100.3;
+
+		double expectedCost = literalOpCost + naryOpCost + naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+			+ binaryOpMultCost + aggBinaryOpCost + dataOpCost + dataOpWriteCost;
+		runTest("FederatedMultiplyCostEstimatorTest.dml", false, expectedCost);
+
+		double aggBinaryActualCost = hops.stream()
+			.filter(hop -> hop instanceof AggBinaryOp)
+			.mapToDouble(aggHop -> aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost())
+			.sum();
+		Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost, 0.0001);
+
+		double writeActualCost = hops.stream()
+			.filter(hop -> hop instanceof DataOp)
+			.mapToDouble(writeHop -> writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost())
+			.sum();
+		Assert.assertEquals(dataOpWriteCost+dataOpCost, writeActualCost, 0.0001);
+	}
+
+	Set<Hop> hops = new HashSet<>(); // hops collected via addHop; (re)filled by runTest for the federated multiply script
+
+	/**
+	 * Recursively adds the hop and its inputs to the set of hops.
+	 * @param hop root to be added to set of hops
+	 */
+	private void addHop(Hop hop){
+		hops.add(hop);
+		for(Hop inHop : hop.getInput()){
+			addHop(inHop);
+		}
+	}
+
+	/**
+	 * Sets "Fed X"/"Fed Y" to 10000x10, marks DataOps and binary multiplications as FOUT (FED exec type) and all other hops as LOUT.
+	 * @param prog dml program where the HOPS are modified
+	 */
+	private void modifyFedouts(DMLProgram prog){
+		prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+		hops.forEach(hop -> {
+			if ( hop instanceof DataOp || (hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == MULT ) ){
+				hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+				hop.setExecType(Types.ExecType.FED);
+			} else {
+				hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+			}
+			if ( hop.getOpString().equals("Fed Y") || hop.getOpString().equals("Fed X") ){
+				hop.setDim1(10000);
+				hop.setDim2(10);
+			}
+		});
+	}
+
+	@SuppressWarnings("unused")
+	private void printHopsInfo(){ // debugging helper: prints hop counts per hop type and the distinct hop classes
+		//LiteralOp
+		long literalCount = hops.stream().filter(hop -> hop instanceof LiteralOp).count();
+		System.out.println("LiteralOp Count: " + literalCount);
+		//NaryOp
+		long naryCount = hops.stream().filter(hop -> hop instanceof NaryOp).count();
+		System.out.println("NaryOp Count " + naryCount);
+		//ReorgOp
+		long reorgCount = hops.stream().filter(hop -> hop instanceof ReorgOp).count();
+		System.out.println("ReorgOp Count: " + reorgCount);
+		//BinaryOp
+		long binaryCount = hops.stream().filter(hop -> hop instanceof BinaryOp).count();
+		System.out.println("Binary count: " + binaryCount);
+		//AggBinaryOp
+		long aggBinaryCount = hops.stream().filter(hop -> hop instanceof AggBinaryOp).count();
+		System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+		//DataOp
+		long dataOpCount = hops.stream().filter(hop -> hop instanceof DataOp).count();
+		System.out.println("DataOp Count: " + dataOpCount);
+
+		hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+	}
+
+	private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) { // parse + construct HOPs, then assert the estimated federated cost
+		boolean raisedException = false;
+		try
+		{
+			setTestConfig(scriptFilename);
+			String dmlScriptString = readScript(scriptFilename);
+
+			//parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
+			ParserWrapper parser = ParserFactory.createParser();
+			DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+			DMLTranslator dmlt = new DMLTranslator(prog);
+			dmlt.liveVariableAnalysis(prog);
+			dmlt.validateParseTree(prog);
+			dmlt.constructHops(prog);
+			if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+				modifyFedouts(prog);
+				dmlt.rewriteHopsDAG(prog);
+				hops = new HashSet<>();
+				prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+			}
+
+			FederatedCost actualCost = fedCostEstimator.costEstimate(prog);
+			Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001);
+		}
+		catch(LanguageException ex) {
+			raisedException = true;
+			if(raisedException!=expectedException)
+				ex.printStackTrace();
+		}
+		catch(Exception ex2) {
+			ex2.printStackTrace();
+			throw new RuntimeException(ex2);
+		}
+
+		Assert.assertEquals("Expected exception does not match raised exception",
+			expectedException, raisedException);
+	}
+
+	private void setTestConfig(String scriptFilename) throws FileNotFoundException {
+		int index = scriptFilename.lastIndexOf(".dml");
+		String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length());
+		TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+		addTestConfiguration(testName, testConfig);
+		loadTestConfiguration(testConfig);
+
+		DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+		ConfigurationManager.setLocalConfig(conf);
+	}
+
+	private static String readScript(String scriptFilename) throws IOException {
+		return DMLScript.readDMLScript(true, HOME + scriptFilename);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 342a26b..e0ef884 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -35,7 +35,7 @@ import org.apache.sysds.test.TestUtils;
 import java.util.Arrays;
 import java.util.Collection;
 
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
@@ -76,34 +76,40 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 
 	@Test
 	public void federatedMultiplyCP() {
-		federatedTwoMatricesSingleNodeTest(TEST_NAME);
+		String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME, expectedHeavyHitters);
 	}
 
 	@Test
 	public void federatedRowSum(){
-		federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+		String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_2, expectedHeavyHitters);
 	}
 
 	@Test
 	public void federatedTernarySequence(){
-		federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
+		String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_3, expectedHeavyHitters);
 	}
 
 	@Test
 	public void federatedAggregateBinarySequence(){
 		cols = rows;
-		federatedTwoMatricesSingleNodeTest(TEST_NAME_4);
+		String[] expectedHeavyHitters = new String[]{"fed_ba+*", "fed_*", "fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_4, expectedHeavyHitters);
 	}
 
 	@Test
 	public void federatedAggregateBinaryColFedSequence(){
 		cols = rows;
-		federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
+		String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_*","fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_5, expectedHeavyHitters);
 	}
 
 	@Test
 	public void federatedAggregateBinarySequence2(){
-		federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
+		String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters);
 	}
 
 	private void writeStandardMatrix(String matrixName, long seed){
@@ -166,11 +172,11 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 		}
 	}
 
-	private void federatedTwoMatricesSingleNodeTest(String testName){
-		federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+	private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){
+		federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, expectedHeavyHitters);
 	}
 
-	private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName) {
+	private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, String[] expectedHeavyHitters) {
 		OptimizerUtils.FEDERATED_COMPILATION = true;
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		Types.ExecMode platformOld = rtplatform;
@@ -213,10 +219,9 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 
 		// compare via files
 		compareResults(1e-9);
-		if ( testName.equals(TEST_NAME_3) )
-			assertTrue(heavyHittersContainsString("fed_+*", "fed_1-*"));
-		else
-			assertTrue(heavyHittersContainsString("fed_*", "fed_ba+*"));
+		if (!heavyHittersContainsAllString(expectedHeavyHitters))
+			fail("The following expected heavy hitters are missing: "
+				+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
 
 		TestUtils.shutdownThreads(t1, t2);
 
diff --git a/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
new file mode 100644
index 0000000..1899614
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Two constant 10x10 matrices, an element-wise multiply, and a print; costed by FederatedCostEstimatorTest.simpleBinary
+a = matrix (1, rows=10, cols=10);
+b = matrix (2, rows=10, cols=10);
+c = a * b;
+print(toString(c));
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
new file mode 100644
index 0000000..dfeacec
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+ph = "placeholder" # stand-in for worker addresses and the output path (the cost test only parses/compiles this script)
+rows = 10000
+cols = 10
+X = federated(addresses=list(ph, ph),
+              ranges=list(list(0, 0), list(rows / 2, cols), list(rows / 2, 0), list(rows, cols)))
+Y = federated(addresses=list(ph, ph),
+              ranges=list(list(0, 0), list(rows/2, cols), list(rows / 2, 0), list(rows, cols)))
+Z0 = X * Y # element-wise multiply of two row-partitioned federated inputs
+Z = t(Z0) %*% X
+write(Z, ph)
diff --git a/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..b80e745
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Cost-estimation fixture: two constant 10x10 matrices multiplied cell-wise and printed, inside a for loop over 1:5.
+for ( i in 1:5 ){  # loop variable i is not used in the body
+    a = matrix (1, rows=10, cols=10);  # 10x10, all ones
+    b = matrix (2, rows=10, cols=10);  # 10x10, all twos
+    c = a * b;  # cell-wise multiply (not %*%)
+    print(toString(c));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
new file mode 100644
index 0000000..1f0d876
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Cost-estimation fixture: the 10x10 multiply-and-print body wrapped in a user-defined function that is called once.
+multiplication = function (){  # no parameters and no return values; only side effect is the print
+    a = matrix (1, rows=10, cols=10);  # 10x10, all ones
+    b = matrix (2, rows=10, cols=10);  # 10x10, all twos
+    c = a * b;  # cell-wise multiply (not %*%)
+    print(toString(c));
+}
+multiplication();  # single call site of the function above
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
new file mode 100644
index 0000000..b0194e3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Cost-estimation fixture: if/else whose then-branch does the 10x10 multiply-and-print and whose else-branch only prints a message.
+if ( 1 ){  # NOTE(review): constant-true predicate, so the else-branch is unreachable at runtime; presumably kept so the estimator sees both branches — confirm
+    a = matrix (1, rows=10, cols=10);  # 10x10, all ones
+    b = matrix (2, rows=10, cols=10);  # 10x10, all twos
+    c = a * b;  # cell-wise multiply (not %*%)
+    print(toString(c));
+}
+else {
+    print("No result");
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..32a49ec
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Cost-estimation fixture: two constant 10x10 matrices multiplied cell-wise and printed, inside a parfor loop over 1:5.
+parfor ( i in 1:5 ){  # loop variable i is not used in the body
+    a = matrix (1, rows=10, cols=10);  # 10x10, all ones
+    b = matrix (2, rows=10, cols=10);  # 10x10, all twos
+    c = a * b;  # cell-wise multiply (not %*%)
+    print(toString(c));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
new file mode 100644
index 0000000..faea01d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Cost-estimation fixture: two constant 10x10 matrices multiplied cell-wise and printed, inside a while loop.
+while ( 1 ){  # NOTE(review): constant-true predicate and nothing in the body ends the loop, so it never terminates if executed; presumably the test only compiles and costs it — confirm
+    a = matrix (1, rows=10, cols=10);  # 10x10, all ones
+    b = matrix (2, rows=10, cols=10);  # 10x10, all twos
+    c = a * b;  # cell-wise multiply (not %*%)
+    print(toString(c));
+}
\ No newline at end of file