You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/02/24 10:06:09 UTC

[systemds] branch main updated: [SYSTEMDS-3018] Federated Cost Estimation for Repetitions

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 37b3e93  [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
37b3e93 is described below

commit 37b3e934ddc9d686d8f6ede9f689038a998ff87a
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Feb 15 11:48:33 2022 +0100

    [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
    
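    This commit changes how the federated plan cost estimation accounts for repetitions when while/for/if statement blocks are used:
    repetition estimates are propagated to the statement blocks and their hops, and each hop's own cost is scaled by its estimated repetitions.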
    Closes #1547.
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  22 +++
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  23 +--
 .../org/apache/sysds/hops/cost/CostEstimator.java  |   4 +-
 .../org/apache/sysds/hops/cost/FederatedCost.java  |  38 ++---
 .../sysds/hops/cost/FederatedCostEstimator.java    |  62 +++----
 .../java/org/apache/sysds/hops/cost/HopRel.java    |   2 +-
 .../hops/ipa/IPAPassRewriteFederatedPlan.java      |   7 +-
 .../java/org/apache/sysds/parser/DMLProgram.java   |   6 +
 .../org/apache/sysds/parser/ForStatementBlock.java |  18 +-
 .../sysds/parser/FunctionStatementBlock.java       |   9 +
 .../org/apache/sysds/parser/IfStatementBlock.java  |  15 +-
 .../org/apache/sysds/parser/StatementBlock.java    |  33 ++++
 .../apache/sysds/parser/WhileStatementBlock.java   |  15 +-
 .../runtime/controlprogram/WhileProgramBlock.java  |   6 +-
 .../fedplanning/FederatedCostEstimatorTest.java    | 181 ++++++++++++++++-----
 15 files changed, 326 insertions(+), 115 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 037bfa5..003492f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -93,6 +93,7 @@ public abstract class Hop implements ParseInfo {
 	 */
 	protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
 	protected FederatedCost _federatedCost = new FederatedCost();
+	protected double repetitions = 1;
 
 	/**
 	 * Field defining if prefetch should be activated for operation.
@@ -996,6 +997,15 @@ public abstract class Hop implements ParseInfo {
 		_federatedCost = cost;
 	}
 
+	/**
+	 * Reset federated cost of this hop and all children of this hop.
+	 */
+	public void resetFederatedCost(){
+		_federatedCost = new FederatedCost();
+		for ( Hop input : getInput() )
+			input.resetFederatedCost();
+	}
+
 	public void setUpdateType(UpdateType update){
 		_updateType = update;
 	}
@@ -1539,6 +1549,18 @@ public abstract class Hop implements ParseInfo {
 		return ret;
 	}
 
+	public void updateRepetitionEstimates(double repetitions){
+		if ( !federatedCostInitialized() ){
+			this.repetitions = repetitions;
+			for ( Hop input : getInput() )
+				input.updateRepetitionEstimates(repetitions);
+		}
+	}
+
+	public double getRepetitions(){
+		return repetitions;
+	}
+
 	/**
 	 * Clones the attributes of that and copies it over to this.
 	 * 
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 47b5822..4d48df6 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -1289,17 +1289,18 @@ public class OptimizerUtils
 		if( fpb.getStatementBlock()==null )
 			return defaultValue;
 		ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
-		try {
-			HashMap<Long,Long> memo = new HashMap<>();
-			long from = rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
-			long to = rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
-			long increment = (fsb.getIncrementHops()==null) ? (from < to) ? 1 : -1 : 
-				rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
-			if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && increment != Long.MAX_VALUE )
-				return (int)Math.ceil(((double)(to-from+1))/increment);
-		}
-		catch(Exception ex){}
-		return defaultValue;
+		return getNumIterations(fsb, defaultValue);
+	}
+
+	public static long getNumIterations(ForStatementBlock fsb, long defaultValue){
+		HashMap<Long,Long> memo = new HashMap<>();
+		long from = rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
+		long to = rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
+		long increment = (fsb.getIncrementHops()==null) ? (from < to) ? 1 : -1 :
+			rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
+		if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && increment != Long.MAX_VALUE )
+			return (int)Math.ceil(((double)(to-from+1))/increment);
+		else return defaultValue;
 	}
 	
 	public static long getNumIterations(ForProgramBlock fpb, LocalVariableMap vars, long defaultValue) {
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index 03948d4..497b807 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -116,7 +116,7 @@ public abstract class CostEstimator
 				for( ProgramBlock pb2 : tmp.getChildBlocks() )
 					ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
 			
-			ret *= getNumIterations(stats, tmp);
+			ret *= getNumIterations(tmp);
 		}
 		else if ( pb instanceof FunctionProgramBlock ) {
 			FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
@@ -413,7 +413,7 @@ public abstract class CostEstimator
 		vs[2] = _unknownStats;
 	}
 		
-	private static long getNumIterations(HashMap<String,VarStats> stats, ForProgramBlock pb) {
+	private static long getNumIterations(ForProgramBlock pb) {
 		return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER);
 	}
 
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
index f4f8db4..8831fdc 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -30,15 +30,20 @@ public class FederatedCost {
 	protected double _outputTransferCost = 0;
 	protected double _inputTotalCost = 0;
 
+	protected double _repetitions = 1;
+	protected double _totalCost;
+
 	public FederatedCost(){}
 
 	public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
-		double computeCost, double inputTotalCost){
+		double computeCost, double inputTotalCost, double repetitions){
 		_readCost = readCost;
 		_inputTransferCost = inputTransferCost;
 		_outputTransferCost = outputTransferCost;
 		_computeCost = computeCost;
 		_inputTotalCost = inputTotalCost;
+		_repetitions = repetitions;
+		_totalCost = calcTotal();
 	}
 
 	/**
@@ -46,15 +51,15 @@ public class FederatedCost {
 	 * @return total cost
 	 */
 	public double getTotal(){
-		return _computeCost + _readCost + _inputTransferCost + _outputTransferCost + _inputTotalCost;
+		return _totalCost;
 	}
 
-	/**
-	 * Multiply the input costs by the number of times the costs are repeated.
-	 * @param repetitionNumber number of repetitions of the costs
-	 */
-	public void addRepetitionCost(int repetitionNumber){
-		_inputTotalCost *= repetitionNumber;
+	private double calcTotal(){
+		return (_computeCost + _readCost + _inputTransferCost + _outputTransferCost) * _repetitions + _inputTotalCost;
+	}
+
+	private void updateTotal(){
+		this._totalCost = calcTotal();
 	}
 
 	/**
@@ -75,6 +80,7 @@ public class FederatedCost {
 	 */
 	public void addInputTotalCost(double additionalCost){
 		_inputTotalCost += additionalCost;
+		updateTotal();
 	}
 
 	/**
@@ -82,19 +88,7 @@ public class FederatedCost {
 	 * @param federatedCost input cost from which the total is retrieved
 	 */
 	public void addInputTotalCost(FederatedCost federatedCost){
-		_inputTotalCost += federatedCost.getTotal();
-	}
-
-	/**
-	 * Add costs of FederatedCost object to this object's current costs.
-	 * @param additionalCost object to add to this object
-	 */
-	public void addFederatedCost(FederatedCost additionalCost){
-		_readCost += additionalCost._readCost;
-		_inputTransferCost += additionalCost._inputTransferCost;
-		_outputTransferCost += additionalCost._outputTransferCost;
-		_computeCost += additionalCost._computeCost;
-		_inputTotalCost += additionalCost._inputTotalCost;
+		addInputTotalCost(federatedCost.getTotal());
 	}
 
 	@Override
@@ -110,6 +104,8 @@ public class FederatedCost {
 		builder.append(_outputTransferCost);
 		builder.append("\n inputTotalCost: ");
 		builder.append(_inputTotalCost);
+		builder.append("\n repetitions: ");
+		builder.append(_repetitions);
 		builder.append("\n total cost: ");
 		builder.append(getTotal());
 		return builder.toString();
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 96a33d4..400caa9 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.hops.cost;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.ipa.MemoTable;
 import org.apache.sysds.parser.DMLProgram;
@@ -39,14 +41,13 @@ import java.util.ArrayList;
  * Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
  */
 public class FederatedCostEstimator {
-	public int DEFAULT_MEMORY_ESTIMATE = 8;
-	public int DEFAULT_ITERATION_NUMBER = 15;
-	public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
-	public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
-	public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
-	public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
+	private static final Log LOG = LogFactory.getLog(FederatedCostEstimator.class.getName());
 
-	public boolean printCosts = false; //Temporary for debugging purposes
+	public static int DEFAULT_MEMORY_ESTIMATE = 8;
+	public static double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
+	public static double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+	public static double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
+	public static double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
 
 	/**
 	 * Estimate cost of given DML program in bytes.
@@ -54,6 +55,7 @@ public class FederatedCostEstimator {
 	 * @return federated cost object with cost estimate in bytes
 	 */
 	public FederatedCost costEstimate(DMLProgram dmlProgram){
+		dmlProgram.updateRepetitionEstimates();
 		FederatedCost programTotalCost = new FederatedCost();
 		for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() )
 			programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
@@ -74,12 +76,9 @@ public class FederatedCostEstimator {
 				for ( StatementBlock bodyBlock : whileStatement.getBody() )
 					whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
 			}
-			whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
 			return whileSBCost;
 		}
 		else if ( sb instanceof IfStatementBlock){
-			//Get cost of if-block + else-block and divide by two
-			// since only one of the code blocks will be executed in the end
 			IfStatementBlock ifSB = (IfStatementBlock) sb;
 			FederatedCost ifSBCost = new FederatedCost();
 			for ( Statement statement : ifSB.getStatements() ){
@@ -89,7 +88,6 @@ public class FederatedCostEstimator {
 				for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
 					ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
 			}
-			ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
 			ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
 			return ifSBCost;
 		}
@@ -106,7 +104,6 @@ public class FederatedCostEstimator {
 				for ( StatementBlock forStatementBlockBody : forStatement.getBody() )
 					forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
 			}
-			forSBCost.addRepetitionCost(forSB.getEstimateReps());
 			return forSBCost;
 		}
 		else if ( sb instanceof FunctionStatementBlock){
@@ -182,12 +179,13 @@ public class FederatedCostEstimator {
 				root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
 			double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
 
+			double rootRepetitions = root.getRepetitions();
 			FederatedCost rootFedCost =
-				new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+				new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts, rootRepetitions);
 			root.setFederatedCost(rootFedCost);
 
-			if ( printCosts )
-				printCosts(root);
+			if ( LOG.isDebugEnabled() )
+				LOG.debug(getCostInfo(root));
 
 			return rootFedCost;
 		}
@@ -199,7 +197,7 @@ public class FederatedCostEstimator {
 	 * @param hopRelMemo memo table of HopRels for calculating input costs
 	 * @return cost estimation of Hop DAG starting from given root HopRel
 	 */
-	public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
+	public static FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
 		// Check if root is in memo table.
 		if ( hopRelMemo.containsHopRel(root) ){
 			return root.getCostObject();
@@ -234,7 +232,8 @@ public class FederatedCostEstimator {
 				root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
 			double readCost = root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
 
-			return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+			double rootRepetitions = root.hopRef.getRepetitions();
+			return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts, rootRepetitions);
 		}
 	}
 
@@ -247,7 +246,7 @@ public class FederatedCostEstimator {
 	 * @param root hopRel for which cost is estimated
 	 * @return input transfer cost estimate
 	 */
-	private double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
+	private static double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
 		if ( hasFederatedInput )
 			return root.inputDependency.stream()
 				.filter(input -> (root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
@@ -275,18 +274,21 @@ public class FederatedCostEstimator {
 	}
 
 	/**
-	 * Prints costs and information about root for debugging purposes
-	 * @param root hop for which information is printed
+	 * Return costs and information about root for debugging purposes.
+	 * @param root hop for which information is returned
+	 * @return information about root cost
 	 */
-	private static void printCosts(Hop root){
-		System.out.println("===============================");
-		System.out.println(root);
-		System.out.println("Is federated: " + root.isFederated());
-		System.out.println("Has federated output: " + root.hasFederatedOutput());
-		System.out.println(root.getText());
-		System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
-		System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
-		System.out.println(root.getFederatedCost().toString());
-		System.out.println("===============================");
+	private static String getCostInfo(Hop root){
+		String sep = System.getProperty("line.separator");
+		StringBuilder costInfo = new StringBuilder();
+		costInfo
+			.append(root).append(sep)
+			.append("Is federated: ").append(root.isFederated())
+			.append(" Has federated output: ").append(root.hasFederatedOutput())
+			.append(root.getText()).append(sep)
+			.append("Pure computeCost: " + ComputeCost.getHOPComputeCost(root))
+			.append(" Dim1: " + root.getDim1() + " Dim2: " + root.getDim2()).append(sep)
+			.append(root.getFederatedCost().toString()).append(sep);
+		return costInfo.toString();
 	}
 }
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index b1cc6dd..bd5ee85 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -55,7 +55,7 @@ public class HopRel {
 		hopRef = associatedHop;
 		this.fedOut = fedOut;
 		setInputDependency(hopRelMemo);
-		cost = new FederatedCostEstimator().costEstimate(this, hopRelMemo);
+		cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index db313af..383be42 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -69,6 +69,10 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
 	 */
 	private final static List<Hop> terminalHops = new ArrayList<>();
 
+	public List<Hop> getTerminalHops(){
+		return terminalHops;
+	}
+
 	/**
 	 * Indicates if an IPA pass is applicable for the current configuration.
 	 * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
@@ -93,6 +97,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
 	@Override
 	public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
 		FunctionCallSizeInfo fcallSizes) {
+		prog.updateRepetitionEstimates();
 		rewriteStatementBlocks(prog, prog.getStatementBlocks());
 		setFinalFedouts();
 		return false;
@@ -178,7 +183,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
 	}
 
 	private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
-		if(sb.getHops() != null && !sb.getHops().isEmpty()) {
+		if(sb.hasHops()) {
 			for(Hop sbHop : sb.getHops()) {
 				if(sbHop instanceof FunctionOp) {
 					String funcName = ((FunctionOp) sbHop).getFunctionName();
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index 498a59d..2edffa7 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -201,6 +201,12 @@ public class DMLProgram
 			throw new RuntimeException(ex);
 		}
 	}
+
+	public void updateRepetitionEstimates(){
+		for ( StatementBlock stmBlock : getStatementBlocks() ){
+			stmBlock.updateRepetitionEstimates(1);
+		}
+	}
 	
 	@Override
 	public String toString(){
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 1acd1ac..b21b9b5 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -447,6 +447,20 @@ public class ForStatementBlock extends StatementBlock
 			}
 		}
 		
-		return 10;
+		return (int) DEFAULT_LOOP_REPETITIONS;
 	}
-}
\ No newline at end of file
+
+	@Override
+	public void updateRepetitionEstimates(double repetitions){
+		this.repetitions = repetitions * getEstimateReps();
+		_fromHops.updateRepetitionEstimates(this.repetitions);
+		_toHops.updateRepetitionEstimates(this.repetitions);
+		_incrementHops.updateRepetitionEstimates(this.repetitions);
+		for(Statement statement : getStatements()) {
+			List<StatementBlock> children = ((ForStatement) statement).getBody();
+			for ( StatementBlock stmBlock : children ){
+				stmBlock.updateRepetitionEstimates(this.repetitions);
+			}
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index cc7ab64..ed70c69 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -258,4 +258,13 @@ public class FunctionStatementBlock extends StatementBlock implements FunctionBl
 		return ProgramConverter
 			.createDeepCopyFunctionStatementBlock(this, new HashSet<>(), new HashSet<>());
 	}
+
+	@Override
+	public void updateRepetitionEstimates(double repetitions){
+		for (Statement stm : getStatements()){
+			for (StatementBlock block : ((FunctionStatement) stm).getBody()){
+				block.updateRepetitionEstimates(repetitions);
+			}
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 4762a14..bae78ca 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -502,7 +502,20 @@ public class IfStatementBlock extends StatementBlock
 		liveInReturn.addVariables(_liveIn);
 		return liveInReturn;
 	}
-	
+
+	@Override
+	public void updateRepetitionEstimates(double repetitions){
+		this.repetitions = repetitions;
+		getPredicateHops().updateRepetitionEstimates(this.repetitions);
+		for ( Statement statement : getStatements() ){
+			IfStatement ifStatement = (IfStatement) statement;
+			double blockLevelReps = repetitions / 2;
+			for ( StatementBlock ifBodySB : ifStatement.getIfBody() )
+				ifBodySB.updateRepetitionEstimates(blockLevelReps);
+			for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
+				elseBodySB.updateRepetitionEstimates(blockLevelReps);
+		}
+	}
 	
 	/////////
 	// materialized hops recompilation flags
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 4f8cd1b..6e9545c 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -29,6 +29,7 @@ import java.util.Map.Entry;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
@@ -64,6 +65,9 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
 	private boolean _splitDag = false;
 	private boolean _nondeterministic = false;
 
+	protected double repetitions = 1;
+	public final static double DEFAULT_LOOP_REPETITIONS = 10;
+
 	public StatementBlock() {
 		_ID = getNextSBID();
 		_name = "SB"+_ID;
@@ -1238,6 +1242,35 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
 		return liveInReturn;
 	}
 
+	public boolean hasHops(){
+		return getHops() != null && !getHops().isEmpty();
+	}
+
+	/**
+	 * Updates the repetition estimate for this statement block
+	 * and all contained hops. FunctionStatementBlocks are loaded
+	 * from the function dictionary and repetitions are estimated
+	 * for the contained statement blocks.
+	 *
+	 * This method is overridden in the subclasses of StatementBlock.
+	 * @param repetitions estimated for this statement block
+	 */
+	public void updateRepetitionEstimates(double repetitions){
+		this.repetitions = repetitions;
+		if ( hasHops() ){
+			for ( Hop root : getHops() ){
+				// Set repetitionNum for hops recursively
+				if(root instanceof FunctionOp) {
+					String funcName = ((FunctionOp) root).getFunctionName();
+					FunctionStatementBlock sbFuncBlock = getDMLProg().getBuiltinFunctionDictionary().getFunction(funcName);
+					sbFuncBlock.updateRepetitionEstimates(repetitions);
+				}
+				else
+					root.updateRepetitionEstimates(repetitions);
+			}
+		}
+	}
+
 	///////////////////////////////////////////////////////////////
 	// validate error handling (consistent for all expressions)
 
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index 7a09242..b28e682 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -22,6 +22,7 @@ package org.apache.sysds.parser;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.Hop;
@@ -317,6 +318,18 @@ public class WhileStatementBlock extends StatementBlock
 		
 		return liveInReturn;
 	}
+
+	@Override
+	public void updateRepetitionEstimates(double repetitions){
+		this.repetitions = repetitions * DEFAULT_LOOP_REPETITIONS;
+		getPredicateHops().updateRepetitionEstimates(this.repetitions);
+		for(Statement statement : getStatements()) {
+			List<StatementBlock> children = ((WhileStatement)statement).getBody();
+			for ( StatementBlock stmBlock : children ){
+				stmBlock.updateRepetitionEstimates(this.repetitions);
+			}
+		}
+	}
 	
 	/////////
 	// materialized hops recompilation flags
@@ -331,4 +344,4 @@ public class WhileStatementBlock extends StatementBlock
 	public boolean requiresPredicateRecompilation() {
 		return _requiresPredicateRecompile;
 	}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
index cc916de..4695b94 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
@@ -20,8 +20,12 @@
 package org.apache.sysds.runtime.controlprogram;
 
 import java.util.ArrayList;
+import java.util.List;
 
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ValueType;
@@ -151,4 +155,4 @@ public class WhileProgramBlock extends ProgramBlock
 	public String printBlockErrorLocation(){
 		return "ERROR: Runtime error in while program block generated from while statement block between lines " + _beginLine + " and " + _endLine + " -- ";
 	}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 906ed1f..b8ad989 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.privacy.fedplanning;
 
+import net.jcip.annotations.NotThreadSafe;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
@@ -32,15 +33,21 @@ import org.apache.sysds.hops.NaryOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.cost.FederatedCost;
 import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.IPAPassRewriteFederatedPlan;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.LanguageException;
 import org.apache.sysds.parser.ParserFactory;
 import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 import java.io.FileNotFoundException;
@@ -51,6 +58,7 @@ import java.util.Set;
 
 import static org.apache.sysds.common.Types.OpOp2.MULT;
 
+@NotThreadSafe
 public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	private static final String TEST_DIR = "functions/privacy/fedplanning/";
@@ -58,13 +66,36 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/";
 	FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
 
+	private static double COMPUTE_FLOPS;
+	private static double READ_PS;
+	private static double NETWORK_PS;
+
 	@Override
 	public void setUp() {}
 
+	@BeforeClass
+	public static void storeConstants(){
+		COMPUTE_FLOPS = FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS;
+		READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS;
+		NETWORK_PS = FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+	}
+
+	@Before
+	public void setConstants(){
+		FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
+		FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+		FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+	}
+
+	@After
+	public void resetConstants(){
+		FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = COMPUTE_FLOPS;
+		FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS;
+		FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = NETWORK_PS;
+	}
+
 	@Test
 	public void simpleBinary() {
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
 
 		/*
 		 * HOP			Occurences		ComputeCost		ReadCost	ComputeCostFinal	ReadCostFinal
@@ -75,70 +106,87 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 		 * TOSTRING		1				1				800			0.0625				80
 		 * UnaryOp		1				1				8			0.0625				0.8
 		 */
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 
 		double expectedCost = computeCost + readCost;
 		runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void simpleBinaryHopRelTest() { // same script as simpleBinary, checked through runHopRelTest; no exception expected
+		runHopRelTest("BinaryCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void ifElseTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); // static constants; values come from setConstants()
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); // read bandwidth likewise set in setConstants()
 		double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
 		runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void ifElseHopRelTest(){ // same script as ifElseTest, checked through runHopRelTest; no exception expected
+		runHopRelTest("IfElseCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void whileTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
-		double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); // static constants; values come from setConstants()
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); // read bandwidth likewise set in setConstants()
+		double expectedCost = (computeCost + readCost + 0.0625 + 0.0625 + 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS; // every term, including the 0.0625/0.8 extras, is now multiplied by the default repetition count
 		runTest("WhileCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void whileHopRelTest(){ // same script as whileTest, checked through runHopRelTest; no exception expected
+		runHopRelTest("WhileCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void forLoopTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); // static constants; values come from setConstants()
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); // read bandwidth likewise set in setConstants()
 		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
 		double expectedCost = (computeCost + readCost + predicateCost) * 5;
 		runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void forLoopHopRelTest(){ // same script as forLoopTest, checked through runHopRelTest; no exception expected
+		runHopRelTest("ForLoopCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void parForLoopTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); // static constants; values come from setConstants()
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); // read bandwidth likewise set in setConstants()
 		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
 		double expectedCost = (computeCost + readCost + predicateCost) * 5;
 		runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void parForLoopHopRelTest(){ // same script as parForLoopTest, checked through runHopRelTest; no exception expected
+		runHopRelTest("ParForLoopCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void functionTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+		double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); // static constants; values come from setConstants()
+		double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); // read bandwidth likewise set in setConstants()
 		double expectedCost = (computeCost + readCost);
 		runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
 	}
 
 	@Test
+	public void functionHopRelTest(){ // same script as functionTest, checked through runHopRelTest; no exception expected
+		runHopRelTest("FunctionCostEstimatorTest.dml", false);
+	}
+
+	@Test
 	public void federatedMultiply() {
-		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
 
 		double literalOpCost = 10*0.0625;
 		double naryOpCostSpecial = (0.125+2.2);
@@ -224,27 +272,72 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 		hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
 	}
 
+	private DMLProgram testSetup(String scriptFilename) throws IOException{ // shared by runTest and runHopRelTest: parse the script and construct its hops, returning the program
+		setTestConfig(scriptFilename);
+		String dmlScriptString = readScript(scriptFilename);
+
+		//parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
+		ParserWrapper parser = ParserFactory.createParser();
+		DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+		DMLTranslator dmlt = new DMLTranslator(prog);
+		dmlt.liveVariableAnalysis(prog);
+		dmlt.validateParseTree(prog);
+		dmlt.constructHops(prog);
+		if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){ // only this script gets the fedout modification, the hop rewrite, and the hop collection below
+			modifyFedouts(prog);
+			dmlt.rewriteHopsDAG(prog);
+			hops = new HashSet<>(); // start from an empty set before collecting via addHop
+			prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+		}
+		return prog;
+	}
+
+	private void compareResults(DMLProgram prog) { // asserts that the cost found on the terminal hops after the federated-plan rewrite equals a fresh estimator run
+		IPAPassRewriteFederatedPlan rewriter = new IPAPassRewriteFederatedPlan();
+		rewriter.rewriteProgram(prog, new FunctionCallGraph(prog), null);
+
+		double actualCost = 0; // sum of the totals already attached to the terminal hops
+		for ( Hop root : rewriter.getTerminalHops() ){
+			actualCost += root.getFederatedCost().getTotal();
+		}
+
+
+		rewriter.getTerminalHops().forEach(Hop::resetFederatedCost); // reset the terminal hops' federated cost before re-estimating
+		fedCostEstimator = new FederatedCostEstimator(); // fresh estimator instance for the reference computation
+		double expectedCost = 0; // sum of the totals produced by costEstimate on the same terminal hops
+		for ( Hop root : rewriter.getTerminalHops() )
+			expectedCost += fedCostEstimator.costEstimate(root).getTotal();
+		Assert.assertEquals(expectedCost, actualCost, 0.0001);
+	}
+
+	private void runHopRelTest( String scriptFilename, boolean expectedException ) { // runs testSetup + compareResults and checks whether a LanguageException occurred as expected
+		boolean raisedException = false;
+		try
+		{
+			DMLProgram prog = testSetup(scriptFilename);
+			compareResults(prog);
+		}
+		catch(LanguageException ex) { // the only exception type that counts as the "expected" failure
+			raisedException = true;
+			if(raisedException!=expectedException) // print the trace only when the exception was not expected
+				ex.printStackTrace();
+		}
+		catch(Exception ex2) { // anything else is printed and rethrown, failing the test
+			ex2.printStackTrace();
+			throw new RuntimeException(ex2);
+		}
+
+		Assert.assertEquals("Expected exception does not match raised exception",
+			expectedException, raisedException);
+	}
+
 	private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) {
 		boolean raisedException = false;
 		try
 		{
-			setTestConfig(scriptFilename);
-			String dmlScriptString = readScript(scriptFilename);
-
-			//parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
-			ParserWrapper parser = ParserFactory.createParser();
-			DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
-			DMLTranslator dmlt = new DMLTranslator(prog);
-			dmlt.liveVariableAnalysis(prog);
-			dmlt.validateParseTree(prog);
-			dmlt.constructHops(prog);
-			if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
-				modifyFedouts(prog);
-				dmlt.rewriteHopsDAG(prog);
-				hops = new HashSet<>();
-				prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
-			}
+			DMLProgram prog = testSetup(scriptFilename);
 
+			fedCostEstimator = new FederatedCostEstimator();
 			FederatedCost actualCost = fedCostEstimator.costEstimate(prog);
 			Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001);
 		}