You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/10/14 09:00:16 UTC

[systemds] branch master updated: [SYSTEMDS-3018] Federated Planner with Memo Table

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 235a165  [SYSTEMDS-3018] Federated Planner with Memo Table
235a165 is described below

commit 235a16530d3e5047d7a45662a1f70fa938ead814
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Aug 24 12:50:04 2021 +0200

    [SYSTEMDS-3018] Federated Planner with Memo Table
    
    This commit:
    (1) Change the federated plan rewriter to take an entire DML program.
    (2) Add a basic HopRel class.
    (3) Add HopRel support to the federated cost estimator.
    
    Closes #1395.
---
 src/main/java/org/apache/sysds/hops/DataOp.java    |  16 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |  20 +-
 .../sysds/hops/cost/FederatedCostEstimator.java    | 116 ++++++--
 .../java/org/apache/sysds/hops/cost/HopRel.java    | 218 ++++++++++++++
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |   4 +-
 .../hops/rewrite/IPAPassRewriteFederatedPlan.java  | 314 +++++++++++++++++++++
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   1 -
 .../hops/rewrite/RewriteFederatedExecution.java    | 123 +-------
 .../rewrite/RewriteFederatedStatementBlocks.java   |  66 -----
 .../fedplanning/FederatedCostEstimatorTest.java    |  26 +-
 10 files changed, 671 insertions(+), 233 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java
index 9bc4607..548417d 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -364,7 +364,8 @@ public class DataOp extends Hop {
 		return( _op == OpOpData.PERSISTENTREAD || _op == OpOpData.PERSISTENTWRITE );
 	}
 
-	public boolean isFederatedData(){
+	@Override
+	public boolean isFederatedDataOp(){
 		return _op == OpOpData.FEDERATED;
 	}
 
@@ -496,17 +497,10 @@ public class DataOp extends Hop {
 			
 			_etype = letype;
 		}
-		
-		return _etype;
-	}
 
-	/**
-	 * True if execution is federated, if output is federated, or if OpOpData is federated.
-	 * @return true if federated
-	 */
-	@Override
-	public boolean isFederated() {
-		return super.isFederated() || getOp() == OpOpData.FEDERATED;
+		updateETFed(); // sets _etype to FED if some input has FOUT or this is a federated DataOp (replaces the removed isFederated override)
+
+		return _etype;
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 45fc3af..a25cf10 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -865,15 +865,19 @@ public abstract class Hop implements ParseInfo {
 	}
 
 	/**
-	 * Update the execution type if input is federated and federated compilation is activated.
-	 * Federated compilation is activated in OptimizerUtils.
+	 * Update the execution type to FED if some input has federated output or if this hop is a federated DataOp.
 	 * This method only has an effect if FEDERATED_COMPILATION is activated.
+	 * Federated compilation is activated in OptimizerUtils; note that this method does not check the flag itself.
 	 */
 	protected void updateETFed(){
-		if ( _federatedOutput.isForced() )
+		if ( someInputFederated() || isFederatedDataOp() )
 			_etype = ExecType.FED;
 	}
-	
+
+	/**
+	 * Checks if ExecType is federated.
+	 * @return true if ExecType is federated
+	 */
 	public boolean isFederated(){
 		return getExecType() == ExecType.FED;
 	}
@@ -882,6 +886,14 @@ public abstract class Hop implements ParseInfo {
 		return getInput().stream().anyMatch(Hop::hasFederatedOutput);
 	}
 
+	/**
+	 * Checks if the hop is a DataOp with federated data.
+	 * @return true if hop is a federated DataOp (always false in the base class)
+	 */
+	public boolean isFederatedDataOp(){
+		return false;
+	}
+
 	public ArrayList<Hop> getParent() {
 		return _parent;
 	}
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 3e2f994..7089ed8 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -33,6 +33,8 @@ import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 
 import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
@@ -41,7 +43,7 @@ public class FederatedCostEstimator {
 	public int DEFAULT_MEMORY_ESTIMATE = 8;
 	public int DEFAULT_ITERATION_NUMBER = 15;
 	public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
-	public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+	public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
 	public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
 	public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
 
@@ -154,34 +156,30 @@ public class FederatedCostEstimator {
 	 * @param root of Hop DAG for which cost is estimated
 	 * @return cost estimation of Hop DAG starting from given root
 	 */
-	private FederatedCost costEstimate(Hop root){
+	public FederatedCost costEstimate(Hop root){
 		if ( root.federatedCostInitialized() )
 			return root.getFederatedCost();
 		else {
 			// If no input has FOUT, the root will be processed by the coordinator
 			boolean hasFederatedInput = root.someInputFederated();
-			//the input cost is included the first time the input hop is used
-			//for additional usage, the additional cost is zero (disregarding potential read cost)
+			// The input cost is included the first time the input hop is used.
+			// For additional usage, the additional cost is zero (disregarding potential read cost).
 			double inputCosts = root.getInput().stream()
 				.mapToDouble( in -> in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
 				.sum();
-			double inputTransferCost = hasFederatedInput ? root.getInput().stream()
-				.filter(Hop::hasLocalOutput)
-				.mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
-				.map(inMem -> inMem/ WORKER_NETWORK_BANDWIDTH_BYTES_PS)
-				.sum() : 0;
+			double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
 			double computingCost = ComputeCost.getHOPComputeCost(root);
 			if ( hasFederatedInput ){
-				//Find the number of inputs that has FOUT set.
+				// Find the number of inputs that have FOUT set.
 				int numWorkers = (int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
-				//divide memory usage by the number of workers the computation would be split to multiplied by
-				//the number of parallel processes at each worker multiplied by the FLOPS of each process
-				//This assumes uniform workload among the workers with FOUT data involved in the operation
-				//and assumes that the degree of parallelism and compute bandwidth are equal for all workers
-				computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
-			} else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
-			//Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
-			double outputTransferCost = ( root.hasLocalOutput() && hasFederatedInput ) ?
+				// Divide the compute cost by the number of workers the computation would be split to multiplied by
+				// the number of parallel processes at each worker multiplied by the FLOPS of each process.
+				// This assumes uniform workload among the workers with FOUT data involved in the operation
+				// and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+				computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			} else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			// Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+			double outputTransferCost = ( root.hasLocalOutput() && (hasFederatedInput || root.isFederatedDataOp()) ) ?
 				root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
 			double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
 
@@ -197,6 +195,88 @@ public class FederatedCostEstimator {
 	}
 
 	/**
+	 * Return cost estimate in bytes of Hop DAG starting from given root HopRel.
+	 * @param root HopRel of Hop DAG for which cost is estimated
+	 * @param hopRelMemo memo table of HopRels for calculating input costs
+	 * @return cost estimation of Hop DAG starting from given root HopRel
+	 */
+	public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> hopRelMemo){
+		// Check if root is in memo table.
+		if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+			&& hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == root.fedOut) ){
+			return root.getCostObject();
+		}
+		else {
+			// Use the planned input HopRels (not the input hops' current FederatedOutput fields); if none is FOUT, the root is processed by the coordinator
+			boolean hasFederatedInput = root.inputDependency.stream().anyMatch(HopRel::hasFederatedOutput);
+			// The input cost is included the first time the input hop is used.
+			// For additional usage, the additional cost is zero (disregarding potential read cost).
+			double inputCosts = root.inputDependency.stream()
+				.mapToDouble( in -> {
+					double inCost = in.existingCostPointer(root.hopRef.getHopID()) ?
+						0 : costEstimate(in, hopRelMemo).getTotal();
+					in.addCostPointer(root.hopRef.getHopID());
+					return inCost;
+				} )
+				.sum();
+			double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
+			double computingCost = ComputeCost.getHOPComputeCost(root.hopRef);
+			if ( hasFederatedInput ){
+				// Find the number of inputs that have FOUT set.
+				int numWorkers = (int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+				// Divide the compute cost by the number of workers the computation would be split to multiplied by
+				// the number of parallel processes at each worker multiplied by the FLOPS of each process
+				// This assumes uniform workload among the workers with FOUT data involved in the operation
+				// and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+				computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			} else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			// Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+			// If the root is a federated DataOp, the data is forced to the coordinator even if no input is LOUT
+			double outputTransferCost = ( root.hasLocalOutput() && (hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+				root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+			double readCost = root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+
+			return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+		}
+	}
+
+	/**
+	 * Returns input transfer cost estimate.
+	 * The input transfer cost estimate is based on the memory estimate of LOUT when some input is FOUT
+	 * except if root is a federated DataOp, since all input for this has to be at the coordinator.
+	 * When no input is FOUT, the input transfer cost is always 0.
+	 * @param hasFederatedInput true if root has any FOUT input
+	 * @param root hopRel for which cost is estimated
+	 * @return input transfer cost estimate
+	 */
+	private double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
+		if ( hasFederatedInput )
+			return root.inputDependency.stream()
+				.filter(input -> (root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
+				.mapToDouble(in -> in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+				.sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+		else return 0;
+	}
+
+	/**
+	 * Returns input transfer cost estimate.
+	 * The input transfer cost estimate is based on the memory estimate of LOUT when some input is FOUT
+	 * except if root is a federated DataOp, since all input for this has to be at the coordinator.
+	 * When no input is FOUT, the input transfer cost is always 0.
+	 * @param hasFederatedInput true if root has any FOUT input
+	 * @param root hop for which cost is estimated
+	 * @return input transfer cost estimate
+	 */
+	private double inputTransferCostEstimate(boolean hasFederatedInput, Hop root){
+		if ( hasFederatedInput )
+			return root.getInput().stream()
+				.filter(input -> (root.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
+				.mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+				.sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+		else return 0;
+	}
+
+	/**
 	 * Prints costs and information about root for debugging purposes
 	 * @param root hop for which information is printed
 	 */
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
new file mode 100644
index 0000000..6191a6c
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal FederatedOutput values for the inputs.
+ */
+public class HopRel {
+	protected final Hop hopRef;
+	protected final FEDInstruction.FederatedOutput fedOut;
+	protected final FederatedCost cost;
+	protected final Set<Long> costPointerSet = new HashSet<>();
+	protected final List<HopRel> inputDependency = new ArrayList<>();
+
+	/**
+	 * Constructs a HopRel with input dependency and cost estimate based on entries in hopRelMemo.
+	 * @param associatedHop hop associated with this HopRel
+	 * @param fedOut FederatedOutput value assigned to this HopRel
+	 * @param hopRelMemo memo table storing other HopRels including the inputs of associatedHop
+	 */
+	public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, Map<Long, List<HopRel>> hopRelMemo){
+		hopRef = associatedHop;
+		this.fedOut = fedOut;
+		setInputDependency(hopRelMemo);
+		cost = new FederatedCostEstimator().costEstimate(this, hopRelMemo); // NOTE(review): 'this' escapes before cost is set; assumes associatedHop is not yet in hopRelMemo
+	}
+
+	/**
+	 * Adds hopID to set of hops pointing to this HopRel.
+	 * By storing the hopID it can later be determined if the cost
+	 * stored in this HopRel is already used as input cost in another HopRel.
+	 * @param hopID added to set of stored cost pointers
+	 */
+	public void addCostPointer(long hopID){
+		costPointerSet.add(hopID);
+	}
+
+	/**
+	 * Checks if another Hop is referring to this HopRel in memo table.
+	 * A reference from a HopRel with same Hop ID is allowed, so this
+	 * ID is ignored when checking references.
+	 * @param currentHopID to ignore when checking references
+	 * @return true if another Hop refers to this HopRel in memo table
+	 */
+	public boolean existingCostPointer(long currentHopID){
+		if ( costPointerSet.contains(currentHopID) )
+			return costPointerSet.size() > 1;
+		else return costPointerSet.size() > 0;
+	}
+
+	public boolean hasLocalOutput(){
+		return fedOut == FederatedOutput.LOUT;
+	}
+
+	public boolean hasFederatedOutput(){
+		return fedOut == FederatedOutput.FOUT;
+	}
+
+	public FederatedOutput getFederatedOutput(){
+		return fedOut;
+	}
+
+	public List<HopRel> getInputDependency(){
+		return inputDependency;
+	}
+
+	public Hop getHopRef(){
+		return hopRef;
+	}
+
+	/**
+	 * Returns FOUT HopRel for given hop found in hopRelMemo, or null if it has none (the hop itself must be in hopRelMemo).
+	 * @param hop to look for in hopRelMemo
+	 * @param hopRelMemo memo table storing HopRels
+	 * @return FOUT HopRel found in hopRelMemo
+	 */
+	private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> hopRelMemo){
+		return hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
+	}
+
+	/**
+	 * Get the HopRel with minimum cost for given hop (throws DMLException if the memo entry list of the hop is empty)
+	 * @param hopRelMemo memo table storing HopRels
+	 * @param input hop for which minimum cost HopRel is found
+	 * @return HopRel with minimum cost for given hop
+	 */
+	private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop input){
+		return hopRelMemo.get(input.getHopID()).stream()
+			.min(Comparator.comparingDouble(a -> a.cost.getTotal()))
+			.orElseThrow(() -> new DMLException("No element in Memo Table found for input"));
+	}
+
+	/**
+	 * Set input dependency: for a FOUT hop that is not a federated DataOp, the input with the cheapest FOUT HopRel is bound to it and all other inputs to their minimum-cost HopRel; otherwise every input is bound to its minimum-cost HopRel.
+	 * @param hopRelMemo memo table storing input HopRels
+	 */
+	private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+		if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
+			if ( fedOut == FederatedOutput.FOUT && !hopRef.isFederatedDataOp() ) {
+				int lowestFOUTIndex = 0;
+				HopRel lowestFOUTHopRel = getFOUTHopRel(hopRef.getInput().get(0), hopRelMemo);
+				for(int i = 1; i < hopRef.getInput().size(); i++) {
+					Hop input = hopRef.getInput(i);
+					HopRel foutHopRel = getFOUTHopRel(input, hopRelMemo);
+					if(lowestFOUTHopRel == null) {
+						lowestFOUTHopRel = foutHopRel;
+						lowestFOUTIndex = i;
+					}
+					else if(foutHopRel != null) {
+						if(foutHopRel.getCost() < lowestFOUTHopRel.getCost()) {
+							lowestFOUTHopRel = foutHopRel;
+							lowestFOUTIndex = i;
+						}
+					}
+				}
+
+				HopRel[] inputHopRels = new HopRel[hopRef.getInput().size()];
+				for(int i = 0; i < hopRef.getInput().size(); i++) {
+					if(i != lowestFOUTIndex) {
+						Hop input = hopRef.getInput(i);
+						inputHopRels[i] = getMinOfInput(hopRelMemo, input);
+					}
+					else {
+						inputHopRels[i] = lowestFOUTHopRel; // null if no input has a FOUT HopRel; rejected by validateInputDependency
+					}
+				}
+				inputDependency.addAll(Arrays.asList(inputHopRels));
+			} else {
+				inputDependency.addAll(
+					hopRef.getInput().stream()
+						.map(input -> getMinOfInput(hopRelMemo, input))
+						.collect(Collectors.toList()));
+			}
+		}
+		validateInputDependency();
+	}
+
+	/**
+	 * Throws exception if any input dependency is null.
+	 * If any of the input dependencies are null, it is not possible to build a federated execution plan.
+	 * If this null-state is not found here, an exception will be thrown at another difficult-to-debug place.
+	 */
+	private void validateInputDependency(){
+		for ( int i = 0; i < inputDependency.size(); i++){
+			if ( inputDependency.get(i) == null)
+				throw new DMLException("HopRel input number " + i + " (" + hopRef.getInput(i) + ")"
+					+ " is null for root: \n" + this);
+		}
+	}
+
+	/**
+	 * Get total cost as double
+	 * @return cost as double
+	 */
+	public double getCost(){
+		return cost.getTotal();
+	}
+
+	/**
+	 * Get cost object
+	 * @return cost object
+	 */
+	public FederatedCost getCostObject(){
+		return cost;
+	}
+
+	@Override
+	public String toString(){
+		StringBuilder strB = new StringBuilder();
+		strB.append(this.getClass().getSimpleName());
+		strB.append(" {HopID: ");
+		strB.append(hopRef.getHopID());
+		strB.append(", Opcode: ");
+		strB.append(hopRef.getOpString());
+		strB.append(", FedOut: ");
+		strB.append(fedOut);
+		strB.append(", Cost: ");
+		strB.append(cost);
+		strB.append(", Number of inputs: ");
+		strB.append(inputDependency.size());
+		strB.append("}");
+		return strB.toString();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 8224192..b0597eb 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -34,6 +34,7 @@ import org.apache.sysds.hops.HopsException;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.IPAPassRewriteFederatedPlan;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.DataIdentifier;
@@ -243,7 +244,8 @@ public class InterProceduralAnalysis
 		List<IPAPass> fpasses = Arrays.asList(
 			new IPAPassRemoveUnusedFunctions(),
 			new IPAPassCompressionWorkloadAnalysis(), // workload-aware compression
-			new IPAPassApplyStaticAndDynamicHopRewrites());  //split after compress
+			new IPAPassApplyStaticAndDynamicHopRewrites(),  //split after compress
+			new IPAPassRewriteFederatedPlan()); // federated planning, only applicable if OptimizerUtils.FEDERATED_COMPILATION
 		for(IPAPass pass : fpasses)
 			if( pass.isApplicable(graph2) )
 				pass.rewriteProgram(_prog, graph2, null);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
new file mode 100644
index 0000000..cbc21cf
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+	private final Map<Long, List<HopRel>> hopRelMemo = new HashMap<>(); // per pass instance (not static), so HopRels and the hop DAGs they reference are not retained across compilations or shared between threads
+
+	/**
+	 * Indicates if this IPA pass is applicable for the current configuration.
+	 * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+	 *
+	 * @param fgraph function call graph
+	 * @return true if federated compilation is activated.
+	 */
+	@Override
+	public boolean isApplicable(FunctionCallGraph fgraph) {
+		return OptimizerUtils.FEDERATED_COMPILATION;
+	}
+
+	/**
+	 * Estimates cost and selects a federated execution plan
+	 * by setting the federated output value of each hop in the program.
+	 * NOTE(review): only prog.getStatementBlocks() is traversed; confirm that function bodies are reached as well.
+	 * @param prog       dml program
+	 * @param fgraph     function call graph
+	 * @param fcallSizes function call size infos
+	 * @return false since the function call graph never has to be rebuilt
+	 */
+	@Override
+	public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
+		rewriteStatementBlocks(prog.getStatementBlocks());
+		return false;
+	}
+
+	/**
+	 * Estimates cost and selects a federated execution plan
+	 * by setting the federated output value of each hop in the statement blocks.
+	 * Each statement block, including its nested blocks, is rewritten via rewriteStatementBlock.
+	 *
+	 * @param sbs   list of statement blocks
+	 * @return list of statement blocks with the federated output value updated for each hop
+	 */
+	public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs) {
+		ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<>();
+		for ( StatementBlock stmBlock : sbs )
+			rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+		return rewrittenStmBlocks;
+	}
+
+	/**
+	 * Estimates cost and selects a federated execution plan
+	 * by setting the federated output value of each hop in the given statement block.
+	 * Nested statement blocks (loop, if, and function bodies) are rewritten recursively.
+	 *
+	 * @param sb    statement block
+	 * @return singleton list containing sb, with the federated output value updated for each hop
+	 */
+	public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb) {
+		if ( sb instanceof WhileStatementBlock)
+			return rewriteWhileStatementBlock((WhileStatementBlock) sb);
+		else if ( sb instanceof IfStatementBlock)
+			return rewriteIfStatementBlock((IfStatementBlock) sb);
+		else if ( sb instanceof ForStatementBlock){
+			// This also includes ParForStatementBlocks
+			return rewriteForStatementBlock((ForStatementBlock) sb);
+		}
+		else if ( sb instanceof FunctionStatementBlock)
+			return rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+		else {
+			// StatementBlock type (no subclass)
+			selectFederatedExecutionPlan(sb.getHops());
+		}
+		return new ArrayList<>(Collections.singletonList(sb));
+	}
+
+	private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+		Hop whilePredicateHop = whileSB.getPredicateHops();
+		selectFederatedExecutionPlan(whilePredicateHop); // plan the predicate DAG, then the loop body
+		for ( Statement stm : whileSB.getStatements() ){
+			WhileStatement whileStm = (WhileStatement) stm;
+			whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+		}
+		return new ArrayList<>(Collections.singletonList(whileSB));
+	}
+
+	private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifSB){
+		selectFederatedExecutionPlan(ifSB.getPredicateHops()); // plan the predicate DAG, then both branches
+		for ( Statement statement : ifSB.getStatements() ){
+			IfStatement ifStatement = (IfStatement) statement;
+			ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+			ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+		}
+		return new ArrayList<>(Collections.singletonList(ifSB));
+	}
+
+	private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forSB){
+		selectFederatedExecutionPlan(forSB.getFromHops());
+		selectFederatedExecutionPlan(forSB.getToHops());
+		if ( forSB.getIncrementHops() != null ) selectFederatedExecutionPlan(forSB.getIncrementHops()); // increment hops are absent if no explicit increment is given
+		for ( Statement statement : forSB.getStatements() ){
+			ForStatement forStatement = ((ForStatement)statement);
+			forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+		}
+		return new ArrayList<>(Collections.singletonList(forSB));
+	}
+
+	private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+		for ( Statement statement : funcSB.getStatements() ){
+			FunctionStatement funcStm = (FunctionStatement) statement;
+			funcStm.setBody(rewriteStatementBlocks(funcStm.getBody())); // only the function body blocks are planned here
+		}
+		return new ArrayList<>(Collections.singletonList(funcSB));
+	}
+
+	/**
+	 * Sets FederatedOutput field of all hops in DAG starting from given root.
+	 * The FederatedOutput chosen for root is the minimum cost HopRel found in memo table for the given root.
+	 * The FederatedOutput values chosen for the inputs to the root are chosen based on the input dependencies.
+	 * @param root hop for which FederatedOutput needs to be set
+	 */
+	private void setFinalFedout(Hop root){
+		HopRel optimalRootHopRel = hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+			.orElseThrow(() -> new DMLException("Hop root " + root + " has no feasible federated output alternatives"));
+		setFinalFedout(root, optimalRootHopRel);
+	}
+
+	/**
+	 * Update the FederatedOutput value and cost based on information stored in given rootHopRel.
+	 * @param root hop for which FederatedOutput is set
+	 * @param rootHopRel from which FederatedOutput value and cost is retrieved
+	 */
+	private void setFinalFedout(Hop root, HopRel rootHopRel){
+		updateFederatedOutput(root, rootHopRel);
+		visitInputDependency(rootHopRel);
+	}
+
+	/**
+	 * Sets FederatedOutput value for each of the inputs of rootHopRel, recursing through their input dependencies (no visited check: shared inputs are revisited once per path and the last visit wins)
+	 * @param rootHopRel which has its input values updated
+	 */
+	private void visitInputDependency(HopRel rootHopRel){
+		List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+		for ( HopRel input : hopRelInputs )
+			setFinalFedout(input.getHopRef(), input);
+	}
+
+	/**
+	 * Updates FederatedOutput value and cost estimate based on updateHopRel values.
+	 * @param root which has its values updated
+	 * @param updateHopRel from which the values are retrieved
+	 */
+	private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+		root.setFederatedOutput(updateHopRel.getFederatedOutput());
+		root.setFederatedCost(updateHopRel.getCostObject());
+	}
+
+	/**
+	 * Select federated execution plan for every Hop in the DAG starting from given roots.
+	 * The cost estimates of the hops are also updated when FederatedOutput is updated in the hops.
+	 * A null root list is treated like an empty one, so nothing is planned.
+	 * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields; may be null
+	 */
+	private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+		// Tolerate missing roots instead of failing with a NullPointerException,
+		// matching the null handling of the hop-level federated rewrite
+		if ( roots == null )
+			return;
+		for ( Hop root : roots )
+			selectFederatedExecutionPlan(root);
+	}
+
+	/**
+	 * Plans federated execution for the DAG below the given root in two phases.
+	 * First, the memo table is filled bottom-up with the supported HopRel alternatives of each hop.
+	 * Then, the cheapest alternative of the root is applied top-down through its input dependencies.
+	 * @param root starting point for going through the Hop DAG to update the federatedOutput fields
+	 */
+	private void selectFederatedExecutionPlan(Hop root){
+		// Phase 1: enumerate alternatives and populate the memo table
+		visitFedPlanHop(root);
+		// Phase 2: choose the cheapest root alternative and propagate its dependencies
+		setFinalFedout(root);
+	}
+
+	/**
+	 * Populates the memo table with the HopRel alternatives of currentHop and of every hop below it.
+	 * Inputs are processed depth-first before currentHop, so their alternatives already exist when
+	 * the alternatives of currentHop are created. Hops that already have a memo entry are skipped.
+	 * @param currentHop the Hop from which the DAG is visited
+	 */
+	private void visitFedPlanHop(Hop currentHop){
+		// A memo table entry marks the hop as already processed
+		if ( hopRelMemo.containsKey(currentHop.getHopID()) )
+			return;
+
+		// Depth-first: every input gets its alternatives before this hop does
+		if ( currentHop.getInput() != null )
+			for ( Hop input : currentHop.getInput() )
+				visitFedPlanHop(input);
+
+		ArrayList<HopRel> alternatives = new ArrayList<>();
+		if ( isFedInstSupportedHop(currentHop) ) {
+			for ( FEDInstruction.FederatedOutput candidate : FEDInstruction.FederatedOutput.values() ) {
+				if ( isFedOutSupported(currentHop, candidate) )
+					alternatives.add(new HopRel(currentHop, candidate, hopRelMemo));
+			}
+		}
+		// Without any supported alternative, fall back to FederatedOutput.NONE
+		if ( alternatives.isEmpty() )
+			alternatives.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
+
+		hopRelMemo.put(currentHop.getHopID(), alternatives);
+		currentHop.setVisited();
+	}
+
+	/**
+	 * Checks if the instructions related to the given hop support FOUT/LOUT processing.
+	 * This is a check on the hop type only. Whether FOUT or LOUT is actually feasible for the hop
+	 * is decided per FederatedOutput value by isFedOutSupported.
+	 * @param hop to check for federated support
+	 * @return true if hop is an AggBinaryOp, BinaryOp, ReorgOp, AggUnaryOp, TernaryOp, or DataOp
+	 */
+	private boolean isFedInstSupportedHop(Hop hop){
+		// Only these hop types have federated instructions supporting FOUT/LOUT. Input-dependent
+		// conditions (e.g. FOUT requiring an input with a FOUT alternative) are checked in isFOUTSupported.
+		return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+			|| hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp);
+	}
+
+	/**
+	 * Decides whether the given FederatedOutput value is a valid alternative for associatedHop.
+	 * FOUT and LOUT are delegated to their dedicated checks. NONE is never reported as supported
+	 * here, because visitFedPlanHop only adds it as a fallback when nothing else is supported.
+	 * Any other value is accepted.
+	 * @param associatedHop to check support of
+	 * @param fedOut federated output value
+	 * @return true if associatedHop supports fedOut
+	 */
+	private boolean isFedOutSupported(Hop associatedHop, FEDInstruction.FederatedOutput fedOut){
+		boolean supported;
+		switch ( fedOut ) {
+			case FOUT:
+				supported = isFOUTSupported(associatedHop);
+				break;
+			case LOUT:
+				supported = isLOUTSupported(associatedHop);
+				break;
+			case NONE:
+				supported = false;
+				break;
+			default:
+				supported = true;
+				break;
+		}
+		return supported;
+	}
+
+	/**
+	 * Checks to see if the associatedHop can have a FOUT alternative.
+	 * A scalar AggUnaryOp result can never be FOUT. Otherwise FOUT is possible when at least one
+	 * input has a FOUT alternative in the memo table, or when the hop itself is a federated DataOp.
+	 * @param associatedHop for which FOUT support is checked
+	 * @return true if FOUT is supported by the associatedHop
+	 */
+	private boolean isFOUTSupported(Hop associatedHop){
+		boolean scalarAggregate = associatedHop instanceof AggUnaryOp && associatedHop.isScalar();
+		if ( scalarAggregate )
+			return false;
+		// When reached from visitFedPlanHop, all inputs have already been given memo table entries
+		boolean someInputFederated = associatedHop.getInput().stream()
+			.anyMatch(input -> hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput));
+		return someInputFederated || associatedHop.isFederatedDataOp();
+	}
+
+	/**
+	 * Checks to see if the associatedHop can have a LOUT alternative.
+	 * This is allowed whenever the hop carries no privacy constraints.
+	 * @param associatedHop for which LOUT support is checked.
+	 * @return true if LOUT is supported by the associatedHop
+	 */
+	private boolean isLOUTSupported(Hop associatedHop){
+		if ( associatedHop.getPrivacy() == null )
+			return true;
+		return !associatedHop.getPrivacy().hasConstraints();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 04cdf32..2e3edb0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,7 +140,6 @@ public class ProgramRewriter
 			}
 			if ( OptimizerUtils.FEDERATED_COMPILATION ) {
 				_dagRuleSet.add( new RewriteFederatedExecution() );
-				_sbRuleSet.add( new RewriteFederatedStatementBlocks() );
 			}
 		}
 		
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index e6a92ce..75fa735 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,14 +23,8 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.sysds.api.DMLException;
-import org.apache.sysds.hops.AggBinaryOp;
-import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -40,8 +34,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
 import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -58,127 +50,24 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.EnumMap;
-import java.util.Map;
 import java.util.concurrent.Future;
 
 public class RewriteFederatedExecution extends HopRewriteRule {
+
 	@Override
 	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
 		if ( roots == null )
 			return null;
 		for ( Hop root : roots )
 			visitHop(root);
-
-		return selectFederatedExecutionPlan(roots);
+		return roots;
 	}
 
 	@Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
 		return null;
 	}
 
-	/**
-	 * Select federated execution plan for every Hop in the DAG starting from given roots.
-	 * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields.
-	 * @return the list of roots with updated FederatedOutput fields.
-	 */
-	private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
-		for (Hop root : roots){
-			root.resetVisitStatus();
-		}
-		for ( Hop root : roots ){
-			visitFedPlanHop(root);
-		}
-		return roots;
-	}
-
-	/**
-	 * Go through the Hop DAG and set the FederatedOutput field for each Hop from leaf to given currentHop.
-	 * @param currentHop the Hop from which the DAG is visited
-	 */
-	private static void visitFedPlanHop(Hop currentHop){
-		if ( currentHop.isVisited() )
-			return;
-		if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
-			// Depth first to get to the input
-			for ( Hop input : currentHop.getInput() )
-				visitFedPlanHop(input);
-		} else if ( isFederatedDataOp(currentHop) ) {
-			// leaf federated node
-			//TODO: This will block the cases where the federated DataOp is based on input that are also federated.
-			// This means that the actual federated leaf nodes will never be reached.
-			currentHop.setFederatedOutput(FederatedOutput.FOUT);
-		}
-		if ( ( isFedInstSupportedHop(currentHop) ) ){
-			// The Hop can be FOUT or LOUT or None. Check utility of FOUT vs LOUT vs None.
-			currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
-		}
-		else
-			currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
-		currentHop.setVisited();
-	}
-
-	/**
-	 * Returns the FederatedOutput with the highest utility out of the valid FederatedOutput values.
-	 * @param hop for which the utility is found
-	 * @return the FederatedOutput value with highest utility for the given Hop
-	 */
-	private static FederatedOutput getHighestUtilFedOut(Hop hop){
-		Map<FederatedOutput,Long> fedOutUtilMap = new EnumMap<>(FederatedOutput.class);
-		if ( isFOUTSupported(hop) )
-			fedOutUtilMap.put(FederatedOutput.FOUT, getUtilFout());
-		if ( hop.getPrivacy() == null || (hop.getPrivacy() != null && !hop.getPrivacy().hasConstraints()) )
-			fedOutUtilMap.put(FederatedOutput.LOUT, getUtilLout(hop));
-		fedOutUtilMap.put(FederatedOutput.NONE, 0L);
-
-		Map.Entry<FederatedOutput, Long> fedOutMax = Collections.max(fedOutUtilMap.entrySet(), Map.Entry.comparingByValue());
-		return fedOutMax.getKey();
-	}
-
-	/**
-	 * Utility if hop is FOUT. This is a simple version where it always returns 1.
-	 * @return utility if hop is FOUT
-	 */
-	private static long getUtilFout(){
-		//TODO: Make better utility estimation
-		return 1;
-	}
-
-	/**
-	 * Utility if hop is LOUT. This is a simple version only based on dimensions.
-	 * @param hop for which utility is calculated
-	 * @return utility if hop is LOUT
-	 */
-	private static long getUtilLout(Hop hop){
-		//TODO: Make better utility estimation
-		return -(long)hop.getMemEstimate();
-	}
-
-	private static boolean isFedInstSupportedHop(Hop hop){
-
-		// Check that some input is FOUT, otherwise none of the fed instructions will run unless it is fedinit
-		if ( (!isFederatedDataOp(hop)) && hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
-			return false;
-
-		// The following operations are supported given that the above conditions have not returned already
-		return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
-			|| hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp );
-	}
-
-	/**
-	 * Checks to see if the associatedHop supports FOUT.
-	 * @param associatedHop for which FOUT support is checked
-	 * @return true if FOUT is supported by the associatedHop
-	 */
-	private static boolean isFOUTSupported(Hop associatedHop){
-		// If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
-		if ( associatedHop instanceof AggUnaryOp )
-			return !associatedHop.isScalar();
-		return true;
-	}
-
-	private static void visitHop(Hop hop){
+	private void visitHop(Hop hop){
 		if (hop.isVisited())
 			return;
 
@@ -206,7 +95,7 @@ public class RewriteFederatedExecution extends HopRewriteRule {
 	 * @hop hop for which privacy constraints are loaded
 	 */
 	private static void loadFederatedPrivacyConstraints(Hop hop){
-		if ( isFederatedDataOp(hop) && hop.getPrivacy() == null){
+		if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){
 			try {
 				PrivacyConstraint privConstraint = unwrapPrivConstraint(sendPrivConstraintRequest(hop));
 				hop.setPrivacy(privConstraint);
@@ -238,10 +127,6 @@ public class RewriteFederatedExecution extends HopRewriteRule {
 		return (PrivacyConstraint) privConstraintResponse.getData()[0];
 	}
 
-	private static boolean isFederatedDataOp(Hop hop){
-		return hop instanceof DataOp && ((DataOp) hop).isFederatedData();
-	}
-
 	/**
 	 * FederatedUDF for retrieving privacy constraint of data stored in file name.
 	 */
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
deleted file mode 100644
index 18b36d5..0000000
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.rewrite;
-
-import org.apache.sysds.parser.StatementBlock;
-
-import java.util.Arrays;
-import java.util.List;
-
-public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
-
-	/**
-	 * Indicates if the rewrite potentially splits dags, which is used
-	 * for phase ordering of rewrites.
-	 *
-	 * @return true if dag splits are possible.
-	 */
-	@Override public boolean createsSplitDag() {
-		return false;
-	}
-
-	/**
-	 * Handle an arbitrary statement block. Specific type constraints have to be ensured
-	 * within the individual rewrites. If a rewrite does not apply to individual blocks, it
-	 * should simply return the input block.
-	 *
-	 * @param sb    statement block
-	 * @param state program rewrite status
-	 * @return list of statement blocks
-	 */
-	@Override
-	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
-		return Arrays.asList(sb);
-	}
-
-	/**
-	 * Handle a list of statement blocks. Specific type constraints have to be ensured
-	 * within the individual rewrites. If a rewrite does not require sequence access, it
-	 * should simply return the input list of statement blocks.
-	 *
-	 * @param sbs   list of statement blocks
-	 * @param state program rewrite status
-	 * @return list of statement blocks
-	 */
-	@Override
-	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
-		return sbs;
-	}
-}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 0092a3a..906ed1f 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -63,7 +63,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void simpleBinary() {
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
 
 		/*
@@ -75,7 +75,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 		 * TOSTRING		1				1				800			0.0625				80
 		 * UnaryOp		1				1				8			0.0625				0.8
 		 */
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 
 		double expectedCost = computeCost + readCost;
@@ -84,9 +84,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void ifElseTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 		double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
 		runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
@@ -94,9 +94,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void whileTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 		double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
 		runTest("WhileCostEstimatorTest.dml", false, expectedCost);
@@ -104,9 +104,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void forLoopTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
 		double expectedCost = (computeCost + readCost + predicateCost) * 5;
@@ -115,9 +115,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void parForLoopTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 		double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
 		double expectedCost = (computeCost + readCost + predicateCost) * 5;
@@ -126,9 +126,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void functionTest(){
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+		double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
 		double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 		double expectedCost = (computeCost + readCost);
 		runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
@@ -136,7 +136,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
 	@Test
 	public void federatedMultiply() {
-		fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+		fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
 		fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
 		fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;