You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/10/14 09:00:16 UTC
[systemds] branch master updated: [SYSTEMDS-3018] Federated Planner with Memo Table
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 235a165 [SYSTEMDS-3018] Federated Planner with Memo Table
235a165 is described below
commit 235a16530d3e5047d7a45662a1f70fa938ead814
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Aug 24 12:50:04 2021 +0200
[SYSTEMDS-3018] Federated Planner with Memo Table
This commit:
(1) Change federated plan rewriter to take an entire DML program.
(2) Add Basic HopRel.
(3) Add HopRel cost estimator.
Closes #1395.
---
src/main/java/org/apache/sysds/hops/DataOp.java | 16 +-
src/main/java/org/apache/sysds/hops/Hop.java | 20 +-
.../sysds/hops/cost/FederatedCostEstimator.java | 116 ++++++--
.../java/org/apache/sysds/hops/cost/HopRel.java | 218 ++++++++++++++
.../sysds/hops/ipa/InterProceduralAnalysis.java | 4 +-
.../hops/rewrite/IPAPassRewriteFederatedPlan.java | 314 +++++++++++++++++++++
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 1 -
.../hops/rewrite/RewriteFederatedExecution.java | 123 +-------
.../rewrite/RewriteFederatedStatementBlocks.java | 66 -----
.../fedplanning/FederatedCostEstimatorTest.java | 26 +-
10 files changed, 671 insertions(+), 233 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java
index 9bc4607..548417d 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -364,7 +364,8 @@ public class DataOp extends Hop {
return( _op == OpOpData.PERSISTENTREAD || _op == OpOpData.PERSISTENTWRITE );
}
- public boolean isFederatedData(){
+ /**
+ * Checks if this DataOp has the op OpOpData.FEDERATED.
+ * Overrides Hop.isFederatedDataOp, whose base implementation returns false.
+ * @return true if the op of this DataOp is OpOpData.FEDERATED
+ */
+ @Override
+ public boolean isFederatedDataOp(){
return _op == OpOpData.FEDERATED;
}
@@ -496,17 +497,10 @@ public class DataOp extends Hop {
_etype = letype;
}
-
- return _etype;
- }
- /**
- * True if execution is federated, if output is federated, or if OpOpData is federated.
- * @return true if federated
- */
- @Override
- public boolean isFederated() {
- return super.isFederated() || getOp() == OpOpData.FEDERATED;
+ updateETFed();
+
+ return _etype;
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 45fc3af..a25cf10 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -865,15 +865,19 @@ public abstract class Hop implements ParseInfo {
}
/**
- * Update the execution type if input is federated and federated compilation is activated.
- * Federated compilation is activated in OptimizerUtils.
+ * Update the execution type if input is federated.
* This method only has an effect if FEDERATED_COMPILATION is activated.
+ * Federated compilation is activated in OptimizerUtils.
+ * Sets the execution type to FED if someInputFederated() or isFederatedDataOp() is true;
+ * otherwise the execution type is left unchanged.
+ * NOTE(review): the body does not check FEDERATED_COMPILATION itself; presumably callers
+ * only invoke it when federated compilation is activated - confirm.
*/
protected void updateETFed(){
- if ( _federatedOutput.isForced() )
+ if ( someInputFederated() || isFederatedDataOp() )
_etype = ExecType.FED;
}
-
+
+ /**
+ * Checks if the execution type of this hop is federated.
+ * @return true if getExecType() returns ExecType.FED
+ */
public boolean isFederated(){
return getExecType() == ExecType.FED;
}
@@ -882,6 +886,14 @@ public abstract class Hop implements ParseInfo {
return getInput().stream().anyMatch(Hop::hasFederatedOutput);
}
+ /**
+ * Checks if the hop is a DataOp with federated data.
+ * The base implementation always returns false; DataOp overrides it.
+ * @return true if hop is a federated DataOp
+ */
+ public boolean isFederatedDataOp(){
+ return false;
+ }
+
public ArrayList<Hop> getParent() {
return _parent;
}
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 3e2f994..7089ed8 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -33,6 +33,8 @@ import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
/**
* Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
@@ -41,7 +43,7 @@ public class FederatedCostEstimator {
+ // Default argument passed to Hop.getOutputMemEstimate/getInputMemEstimate by the cost estimates;
+ // presumably used when no size estimate is available - confirm.
public int DEFAULT_MEMORY_ESTIMATE = 8;
public int DEFAULT_ITERATION_NUMBER = 15;
public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
- public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+ public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
@@ -154,34 +156,30 @@ public class FederatedCostEstimator {
* @param root of Hop DAG for which cost is estimated
* @return cost estimation of Hop DAG starting from given root
*/
- private FederatedCost costEstimate(Hop root){
+ public FederatedCost costEstimate(Hop root){
if ( root.federatedCostInitialized() )
return root.getFederatedCost();
else {
// If no input has FOUT, the root will be processed by the coordinator
boolean hasFederatedInput = root.someInputFederated();
- //the input cost is included the first time the input hop is used
- //for additional usage, the additional cost is zero (disregarding potential read cost)
+ // The input cost is included the first time the input hop is used.
+ // For additional usage, the additional cost is zero (disregarding potential read cost).
double inputCosts = root.getInput().stream()
.mapToDouble( in -> in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
.sum();
- double inputTransferCost = hasFederatedInput ? root.getInput().stream()
- .filter(Hop::hasLocalOutput)
- .mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
- .map(inMem -> inMem/ WORKER_NETWORK_BANDWIDTH_BYTES_PS)
- .sum() : 0;
+ double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
double computingCost = ComputeCost.getHOPComputeCost(root);
if ( hasFederatedInput ){
- //Find the number of inputs that has FOUT set.
+ // Find the number of inputs that has FOUT set.
int numWorkers = (int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
- //divide memory usage by the number of workers the computation would be split to multiplied by
- //the number of parallel processes at each worker multiplied by the FLOPS of each process
- //This assumes uniform workload among the workers with FOUT data involved in the operation
- //and assumes that the degree of parallelism and compute bandwidth are equal for all workers
- computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
- } else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
- //Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
- double outputTransferCost = ( root.hasLocalOutput() && hasFederatedInput ) ?
+ // Divide memory usage by the number of workers the computation would be split to multiplied by
+ // the number of parallel processes at each worker multiplied by the FLOPS of each process.
+ // This assumes uniform workload among the workers with FOUT data involved in the operation
+ // and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+ computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+ } else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+ // Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+ double outputTransferCost = ( root.hasLocalOutput() && (hasFederatedInput || root.isFederatedDataOp()) ) ?
root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
@@ -197,6 +195,88 @@ public class FederatedCostEstimator {
}
/**
+	 * Return cost estimate in bytes of Hop DAG starting from given root HopRel.
+	 * If the memo table already holds a HopRel for the root hop with the same FederatedOutput value,
+	 * the cost stored in root is returned without a new estimate.
+	 * The estimate is based on the FederatedOutput values of the HopRels in the input dependency of root
+	 * (the candidate plan), not on the FederatedOutput values currently stored in the input hops.
+	 * @param root HopRel of Hop DAG for which cost is estimated
+	 * @param hopRelMemo memo table of HopRels for calculating input costs
+	 * @return cost estimation of Hop DAG starting from given root HopRel
+	 */
+	public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> hopRelMemo){
+		// Check if root is in memo table.
+		if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+			&& hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == root.fedOut) ){
+			return root.getCostObject();
+		}
+		else {
+			// If no input HopRel is FOUT, the root will be processed by the coordinator.
+			// Use the FederatedOutput of the input HopRels (the candidate plan), consistent with the
+			// worker count below and with inputTransferCostEstimate. IPAPassRewriteFederatedPlan only
+			// writes FederatedOutput into the hops after a plan is selected, so the hop field cannot be
+			// used here. This also guarantees numWorkers > 0 whenever hasFederatedInput is true.
+			boolean hasFederatedInput = root.inputDependency.stream().anyMatch(HopRel::hasFederatedOutput);
+			// The input cost is included the first time the input hop is used.
+			// For additional usage, the additional cost is zero (disregarding potential read cost).
+			double inputCosts = root.inputDependency.stream()
+				.mapToDouble( in -> {
+					double inCost = in.existingCostPointer(root.hopRef.getHopID()) ?
+						0 : costEstimate(in, hopRelMemo).getTotal();
+					// Record that this root now accounts for the cost of the input.
+					in.addCostPointer(root.hopRef.getHopID());
+					return inCost;
+				} )
+				.sum();
+			double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
+			double computingCost = ComputeCost.getHOPComputeCost(root.hopRef);
+			if ( hasFederatedInput ){
+				// Find the number of inputs that has FOUT set.
+				int numWorkers = (int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+				// Divide memory usage by the number of workers the computation would be split to multiplied by
+				// the number of parallel processes at each worker multiplied by the FLOPS of each process.
+				// This assumes uniform workload among the workers with FOUT data involved in the operation
+				// and assumes that the degree of parallelism and compute bandwidth are equal for all workers.
+				computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			} else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWIDTH_FLOPS);
+			// Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator.
+			// If the root is a federated DataOp, the data is forced to the coordinator even if no input is LOUT.
+			double outputTransferCost = ( root.hasLocalOutput() && (hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+				root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+			double readCost = root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+
+			return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+		}
+	}
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hopRel for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
+ if ( hasFederatedInput )
+ return root.inputDependency.stream()
+ .filter(input -> (root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
+ .mapToDouble(in -> in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+ else return 0;
+ }
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hop for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput, Hop root){
+ if ( hasFederatedInput )
+ return root.getInput().stream()
+ .filter(input -> (root.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
+ .mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+ else return 0;
+ }
+
+ /**
* Prints costs and information about root for debugging purposes
* @param root hop for which information is printed
*/
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
new file mode 100644
index 0000000..6191a6c
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal FederatedOutput values for the inputs.
+ */
+public class HopRel {
+	protected final Hop hopRef;
+	protected final FederatedOutput fedOut;
+	protected final FederatedCost cost;
+	protected final Set<Long> costPointerSet = new HashSet<>();
+	protected final List<HopRel> inputDependency = new ArrayList<>();
+
+	/**
+	 * Creates a HopRel for the given hop and FederatedOutput value.
+	 * The input dependency is chosen among the HopRels already stored in hopRelMemo,
+	 * after which the cost is estimated by a new FederatedCostEstimator.
+	 * @param associatedHop hop associated with this HopRel
+	 * @param fedOut FederatedOutput value assigned to this HopRel
+	 * @param hopRelMemo memo table storing other HopRels including the inputs of associatedHop
+	 */
+	public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, Map<Long, List<HopRel>> hopRelMemo){
+		this.hopRef = associatedHop;
+		this.fedOut = fedOut;
+		setInputDependency(hopRelMemo);
+		this.cost = new FederatedCostEstimator().costEstimate(this, hopRelMemo);
+	}
+
+	/**
+	 * Records that the hop with the given ID uses the cost of this HopRel.
+	 * The recorded IDs are later used to decide if the cost of this HopRel
+	 * has already been counted as input cost of another HopRel.
+	 * @param hopID added to set of stored cost pointers
+	 */
+	public void addCostPointer(long hopID){
+		costPointerSet.add(hopID);
+	}
+
+	/**
+	 * Checks if a hop other than the one with currentHopID refers to this HopRel.
+	 * A reference from a HopRel with the same hop ID is allowed, so that ID is not counted.
+	 * @param currentHopID to ignore when checking references
+	 * @return true if another Hop refers to this HopRel in memo table
+	 */
+	public boolean existingCostPointer(long currentHopID){
+		int ownPointers = costPointerSet.contains(currentHopID) ? 1 : 0;
+		return costPointerSet.size() - ownPointers > 0;
+	}
+
+	public boolean hasLocalOutput(){
+		return FederatedOutput.LOUT == fedOut;
+	}
+
+	public boolean hasFederatedOutput(){
+		return FederatedOutput.FOUT == fedOut;
+	}
+
+	public FederatedOutput getFederatedOutput(){
+		return fedOut;
+	}
+
+	public List<HopRel> getInputDependency(){
+		return inputDependency;
+	}
+
+	public Hop getHopRef(){
+		return hopRef;
+	}
+
+	/**
+	 * Looks up the first FOUT HopRel stored in hopRelMemo for the given hop.
+	 * @param hop to look for in hopRelMemo
+	 * @param hopRelMemo memo table storing HopRels
+	 * @return FOUT HopRel found in hopRelMemo, or null if there is none
+	 */
+	private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> hopRelMemo){
+		for ( HopRel candidate : hopRelMemo.get(hop.getHopID()) ){
+			if ( candidate.fedOut == FederatedOutput.FOUT )
+				return candidate;
+		}
+		return null;
+	}
+
+	/**
+	 * Finds the cheapest HopRel stored in hopRelMemo for the given hop (first one on ties).
+	 * @param hopRelMemo memo table storing HopRels
+	 * @param input hop for which minimum cost HopRel is found
+	 * @return HopRel with minimum cost for given hop
+	 * @throws DMLException if the memo table holds no HopRel for the hop
+	 */
+	private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop input){
+		Comparator<HopRel> byTotalCost = Comparator.comparingDouble(rel -> rel.cost.getTotal());
+		return hopRelMemo.get(input.getHopID()).stream()
+			.min(byTotalCost)
+			.orElseThrow(() -> new DMLException("No element in Memo Table found for input"));
+	}
+
+	/**
+	 * Selects a valid and optimal input dependency for this HopRel and stores it in inputDependency.
+	 * @param hopRelMemo memo table storing input HopRels
+	 */
+	private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+		List<Hop> inputs = hopRef.getInput();
+		if ( inputs != null && !inputs.isEmpty() ){
+			boolean needsFederatedInput = fedOut == FederatedOutput.FOUT && !hopRef.isFederatedDataOp();
+			if ( needsFederatedInput )
+				inputDependency.addAll(selectFOUTInputDependency(inputs, hopRelMemo));
+			else
+				inputDependency.addAll(inputs.stream()
+					.map(input -> getMinOfInput(hopRelMemo, input))
+					.collect(Collectors.toList()));
+		}
+		validateInputDependency();
+	}
+
+	/**
+	 * Chooses the inputs of a FOUT HopRel whose hop is not a federated DataOp.
+	 * The input with the cheapest FOUT HopRel (first one on ties) keeps that FOUT HopRel;
+	 * every other input gets its minimum-cost HopRel.
+	 * If no input has a FOUT HopRel, the last input is assigned null, which validateInputDependency reports.
+	 * @param inputs hop inputs of this HopRel (non-empty)
+	 * @param hopRelMemo memo table storing input HopRels
+	 * @return chosen HopRel for each input, in input order
+	 */
+	private List<HopRel> selectFOUTInputDependency(List<Hop> inputs, Map<Long, List<HopRel>> hopRelMemo){
+		int cheapestIndex = inputs.size() - 1;
+		HopRel cheapestFOUT = null;
+		for ( int i = 0; i < inputs.size(); i++ ){
+			HopRel candidate = getFOUTHopRel(inputs.get(i), hopRelMemo);
+			if ( candidate != null && (cheapestFOUT == null || candidate.getCost() < cheapestFOUT.getCost()) ){
+				cheapestFOUT = candidate;
+				cheapestIndex = i;
+			}
+		}
+		HopRel[] chosen = new HopRel[inputs.size()];
+		for ( int i = 0; i < inputs.size(); i++ )
+			chosen[i] = (i == cheapestIndex) ? cheapestFOUT : getMinOfInput(hopRelMemo, inputs.get(i));
+		return Arrays.asList(chosen);
+	}
+
+	/**
+	 * Throws an exception if any input dependency is null.
+	 * A null input dependency makes it impossible to build a federated execution plan,
+	 * and failing here gives a clearer error than failing later at a difficult-to-debug place.
+	 */
+	private void validateInputDependency(){
+		int i = inputDependency.indexOf(null);
+		if ( i >= 0 )
+			throw new DMLException("HopRel input number " + i + " (" + hopRef.getInput(i) + ")"
+				+ " is null for root: \n" + this);
+	}
+
+	/**
+	 * Get total cost as double
+	 * @return cost as double
+	 */
+	public double getCost(){
+		return cost.getTotal();
+	}
+
+	/**
+	 * Get cost object
+	 * @return cost object
+	 */
+	public FederatedCost getCostObject(){
+		return cost;
+	}
+
+	@Override
+	public String toString(){
+		return getClass().getSimpleName()
+			+ " {HopID: " + hopRef.getHopID()
+			+ ", Opcode: " + hopRef.getOpString()
+			+ ", FedOut: " + fedOut
+			+ ", Cost: " + cost
+			+ ", Number of inputs: " + inputDependency.size()
+			+ "}";
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 8224192..b0597eb 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -34,6 +34,7 @@ import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.IPAPassRewriteFederatedPlan;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.DataIdentifier;
@@ -243,7 +244,8 @@ public class InterProceduralAnalysis
List<IPAPass> fpasses = Arrays.asList(
new IPAPassRemoveUnusedFunctions(),
new IPAPassCompressionWorkloadAnalysis(), // workload-aware compression
- new IPAPassApplyStaticAndDynamicHopRewrites()); //split after compress
+ new IPAPassApplyStaticAndDynamicHopRewrites(), //split after compress
+ new IPAPassRewriteFederatedPlan());
for(IPAPass pass : fpasses)
if( pass.isApplicable(graph2) )
pass.rewriteProgram(_prog, graph2, null);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
new file mode 100644
index 0000000..cbc21cf
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+ private final static Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
+
+ /**
+ * Indicates if an IPA pass is applicable for the current configuration.
+ * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+ *
+ * @param fgraph function call graph
+ * @return true if federated compilation is activated.
+ */
+ @Override
+ public boolean isApplicable(FunctionCallGraph fgraph) {
+ return OptimizerUtils.FEDERATED_COMPILATION;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the program.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @param fcallSizes function call size infos
+ * @return false since the function call graph never has to be rebuilt
+ */
+ @Override
+ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
+ rewriteStatementBlocks(prog.getStatementBlocks());
+ return false;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sbs list of statement blocks
+ * @return list of statement blocks with the federated output value updated for each hop
+ */
+ public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs) {
+ ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<>();
+ for ( StatementBlock stmBlock : sbs )
+ rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+ return rewrittenStmBlocks;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sb statement block
+ * @return list of statement blocks with the federated output value updated for each hop
+ */
+ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb) {
+ if ( sb instanceof WhileStatementBlock)
+ return rewriteWhileStatementBlock((WhileStatementBlock) sb);
+ else if ( sb instanceof IfStatementBlock)
+ return rewriteIfStatementBlock((IfStatementBlock) sb);
+ else if ( sb instanceof ForStatementBlock){
+ // This also includes ParForStatementBlocks
+ return rewriteForStatementBlock((ForStatementBlock) sb);
+ }
+ else if ( sb instanceof FunctionStatementBlock)
+ return rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+ else {
+ // StatementBlock type (no subclass)
+ selectFederatedExecutionPlan(sb.getHops());
+ }
+ return new ArrayList<>(Collections.singletonList(sb));
+ }
+
+ private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+ Hop whilePredicateHop = whileSB.getPredicateHops();
+ selectFederatedExecutionPlan(whilePredicateHop);
+ for ( Statement stm : whileSB.getStatements() ){
+ WhileStatement whileStm = (WhileStatement) stm;
+ whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(whileSB));
+ }
+
+ private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifSB){
+ selectFederatedExecutionPlan(ifSB.getPredicateHops());
+ for ( Statement statement : ifSB.getStatements() ){
+ IfStatement ifStatement = (IfStatement) statement;
+ ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+ ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(ifSB));
+ }
+
+ private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forSB){
+ selectFederatedExecutionPlan(forSB.getFromHops());
+ selectFederatedExecutionPlan(forSB.getToHops());
+ selectFederatedExecutionPlan(forSB.getIncrementHops());
+ for ( Statement statement : forSB.getStatements() ){
+ ForStatement forStatement = ((ForStatement)statement);
+ forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(forSB));
+ }
+
+ private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+ for ( Statement statement : funcSB.getStatements() ){
+ FunctionStatement funcStm = (FunctionStatement) statement;
+ funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(funcSB));
+ }
+
+ /**
+ * Sets FederatedOutput field of all hops in DAG starting from given root.
+ * The FederatedOutput chosen for root is the minimum cost HopRel found in memo table for the given root.
+ * The FederatedOutput values chosen for the inputs to the root are chosen based on the input dependencies.
+ * @param root hop for which FederatedOutput needs to be set
+ */
+ private void setFinalFedout(Hop root){
+ HopRel optimalRootHopRel = hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+ .orElseThrow(() -> new DMLException("Hop root " + root + " has no feasible federated output alternatives"));
+ setFinalFedout(root, optimalRootHopRel);
+ }
+
+ /**
+ * Update the FederatedOutput value and cost based on information stored in given rootHopRel.
+ * @param root hop for which FederatedOutput is set
+ * @param rootHopRel from which FederatedOutput value and cost is retrieved
+ */
+ private void setFinalFedout(Hop root, HopRel rootHopRel){
+ updateFederatedOutput(root, rootHopRel);
+ visitInputDependency(rootHopRel);
+ }
+
+ /**
+ * Sets FederatedOutput value for each of the inputs of rootHopRel
+ * @param rootHopRel which has its input values updated
+ */
+ private void visitInputDependency(HopRel rootHopRel){
+ List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+ for ( HopRel input : hopRelInputs )
+ setFinalFedout(input.getHopRef(), input);
+ }
+
+ /**
+ * Updates FederatedOutput value and cost estimate based on updateHopRel values.
+ * @param root which has its values updated
+ * @param updateHopRel from which the values are retrieved
+ */
+ private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+ root.setFederatedOutput(updateHopRel.getFederatedOutput());
+ root.setFederatedCost(updateHopRel.getCostObject());
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting from given roots.
+ * The cost estimates of the hops are also updated when FederatedOutput is updated in the hops.
+ * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields.
+ */
+ private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+ for ( Hop root : roots )
+ selectFederatedExecutionPlan(root);
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting from given root.
+ * @param root starting point for going through the Hop DAG to update the federatedOutput fields
+ */
+ private void selectFederatedExecutionPlan(Hop root){
+ visitFedPlanHop(root);
+ setFinalFedout(root);
+ }
+
+ /**
+ * Go through the Hop DAG and set the FederatedOutput field and cost estimate for each Hop from leaf to given currentHop.
+ * @param currentHop the Hop from which the DAG is visited
+ */
+ private void visitFedPlanHop(Hop currentHop){
+ // If the currentHop is in the hopRelMemo table, it means that it has been visited
+ if ( hopRelMemo.containsKey(currentHop.getHopID()) )
+ return;
+ // If the currentHop has input, then the input should be visited depth-first
+ if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 ){
+ for ( Hop input : currentHop.getInput() )
+ visitFedPlanHop(input);
+ }
+ // Put FOUT, LOUT, and None HopRels into the memo table
+ ArrayList<HopRel> hopRels = new ArrayList<>();
+ if ( isFedInstSupportedHop(currentHop) ){
+ for ( FEDInstruction.FederatedOutput fedoutValue : FEDInstruction.FederatedOutput.values() )
+ if ( isFedOutSupported(currentHop, fedoutValue) )
+ hopRels.add(new HopRel(currentHop,fedoutValue, hopRelMemo));
+ }
+ if ( hopRels.isEmpty() )
+ hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
+ hopRelMemo.put(currentHop.getHopID(), hopRels);
+ currentHop.setVisited();
+ }
+
+ /**
+ * Checks if the instructions related to the given hop supports FOUT/LOUT processing.
+ * @param hop to check for federated support
+ * @return true if federated instructions related to hop supports FOUT/LOUT processing
+ */
+ private boolean isFedInstSupportedHop(Hop hop){
+ // The following operations are supported given that the above conditions have not returned already
+ return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+ || hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp);
+ }
+
+ /**
+ * Checks if the associatedHop supports the given federated output value.
+ * @param associatedHop to check support of
+ * @param fedOut federated output value
+ * @return true if associatedHop supports fedOut
+ */
+ private boolean isFedOutSupported(Hop associatedHop, FEDInstruction.FederatedOutput fedOut){
+ switch(fedOut){
+ case FOUT:
+ return isFOUTSupported(associatedHop);
+ case LOUT:
+ return isLOUTSupported(associatedHop);
+ case NONE:
+ return false;
+ default:
+ return true;
+ }
+ }
+
+ /**
+ * Checks to see if the associatedHop supports FOUT.
+ * @param associatedHop for which FOUT support is checked
+ * @return true if FOUT is supported by the associatedHop
+ */
+ private boolean isFOUTSupported(Hop associatedHop){
+ // If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
+ if ( associatedHop instanceof AggUnaryOp && associatedHop.isScalar() )
+ return false;
+ // It can only be FOUT if at least one of the inputs are FOUT, except if it is a federated DataOp
+ if ( associatedHop.getInput().stream().noneMatch(
+ input -> hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput) )
+ && !associatedHop.isFederatedDataOp() )
+ return false;
+ return true;
+ }
+
+ /**
+ * Checks to see if the associatedHop supports LOUT.
+ * It supports LOUT if the output has no privacy constraints.
+ * @param associatedHop for which LOUT support is checked.
+ * @return true if LOUT is supported by the associatedHop
+ */
+ private boolean isLOUTSupported(Hop associatedHop){
+ return associatedHop.getPrivacy() == null || !associatedHop.getPrivacy().hasConstraints();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 04cdf32..2e3edb0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,7 +140,6 @@ public class ProgramRewriter
}
if ( OptimizerUtils.FEDERATED_COMPILATION ) {
_dagRuleSet.add( new RewriteFederatedExecution() );
- _sbRuleSet.add( new RewriteFederatedStatementBlocks() );
}
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index e6a92ce..75fa735 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,14 +23,8 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.api.DMLException;
-import org.apache.sysds.hops.AggBinaryOp;
-import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -40,8 +34,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -58,127 +50,24 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.EnumMap;
-import java.util.Map;
import java.util.concurrent.Future;
public class RewriteFederatedExecution extends HopRewriteRule {
+
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if ( roots == null )
return null;
for ( Hop root : roots )
visitHop(root);
-
- return selectFederatedExecutionPlan(roots);
+ return roots;
}
@Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
return null;
}
- /**
- * Select federated execution plan for every Hop in the DAG starting from given roots.
- * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields.
- * @return the list of roots with updated FederatedOutput fields.
- */
- private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
- for (Hop root : roots){
- root.resetVisitStatus();
- }
- for ( Hop root : roots ){
- visitFedPlanHop(root);
- }
- return roots;
- }
-
- /**
- * Go through the Hop DAG and set the FederatedOutput field for each Hop from leaf to given currentHop.
- * @param currentHop the Hop from which the DAG is visited
- */
- private static void visitFedPlanHop(Hop currentHop){
- if ( currentHop.isVisited() )
- return;
- if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
- // Depth first to get to the input
- for ( Hop input : currentHop.getInput() )
- visitFedPlanHop(input);
- } else if ( isFederatedDataOp(currentHop) ) {
- // leaf federated node
- //TODO: This will block the cases where the federated DataOp is based on input that are also federated.
- // This means that the actual federated leaf nodes will never be reached.
- currentHop.setFederatedOutput(FederatedOutput.FOUT);
- }
- if ( ( isFedInstSupportedHop(currentHop) ) ){
- // The Hop can be FOUT or LOUT or None. Check utility of FOUT vs LOUT vs None.
- currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
- }
- else
- currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
- currentHop.setVisited();
- }
-
- /**
- * Returns the FederatedOutput with the highest utility out of the valid FederatedOutput values.
- * @param hop for which the utility is found
- * @return the FederatedOutput value with highest utility for the given Hop
- */
- private static FederatedOutput getHighestUtilFedOut(Hop hop){
- Map<FederatedOutput,Long> fedOutUtilMap = new EnumMap<>(FederatedOutput.class);
- if ( isFOUTSupported(hop) )
- fedOutUtilMap.put(FederatedOutput.FOUT, getUtilFout());
- if ( hop.getPrivacy() == null || (hop.getPrivacy() != null && !hop.getPrivacy().hasConstraints()) )
- fedOutUtilMap.put(FederatedOutput.LOUT, getUtilLout(hop));
- fedOutUtilMap.put(FederatedOutput.NONE, 0L);
-
- Map.Entry<FederatedOutput, Long> fedOutMax = Collections.max(fedOutUtilMap.entrySet(), Map.Entry.comparingByValue());
- return fedOutMax.getKey();
- }
-
- /**
- * Utility if hop is FOUT. This is a simple version where it always returns 1.
- * @return utility if hop is FOUT
- */
- private static long getUtilFout(){
- //TODO: Make better utility estimation
- return 1;
- }
-
- /**
- * Utility if hop is LOUT. This is a simple version only based on dimensions.
- * @param hop for which utility is calculated
- * @return utility if hop is LOUT
- */
- private static long getUtilLout(Hop hop){
- //TODO: Make better utility estimation
- return -(long)hop.getMemEstimate();
- }
-
- private static boolean isFedInstSupportedHop(Hop hop){
-
- // Check that some input is FOUT, otherwise none of the fed instructions will run unless it is fedinit
- if ( (!isFederatedDataOp(hop)) && hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
- return false;
-
- // The following operations are supported given that the above conditions have not returned already
- return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
- || hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp );
- }
-
- /**
- * Checks to see if the associatedHop supports FOUT.
- * @param associatedHop for which FOUT support is checked
- * @return true if FOUT is supported by the associatedHop
- */
- private static boolean isFOUTSupported(Hop associatedHop){
- // If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
- if ( associatedHop instanceof AggUnaryOp )
- return !associatedHop.isScalar();
- return true;
- }
-
- private static void visitHop(Hop hop){
+ private void visitHop(Hop hop){
if (hop.isVisited())
return;
@@ -206,7 +95,7 @@ public class RewriteFederatedExecution extends HopRewriteRule {
* @hop hop for which privacy constraints are loaded
*/
private static void loadFederatedPrivacyConstraints(Hop hop){
- if ( isFederatedDataOp(hop) && hop.getPrivacy() == null){
+ if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){
try {
PrivacyConstraint privConstraint = unwrapPrivConstraint(sendPrivConstraintRequest(hop));
hop.setPrivacy(privConstraint);
@@ -238,10 +127,6 @@ public class RewriteFederatedExecution extends HopRewriteRule {
return (PrivacyConstraint) privConstraintResponse.getData()[0];
}
- private static boolean isFederatedDataOp(Hop hop){
- return hop instanceof DataOp && ((DataOp) hop).isFederatedData();
- }
-
/**
* FederatedUDF for retrieving privacy constraint of data stored in file name.
*/
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
deleted file mode 100644
index 18b36d5..0000000
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.rewrite;
-
-import org.apache.sysds.parser.StatementBlock;
-
-import java.util.Arrays;
-import java.util.List;
-
-public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
-
- /**
- * Indicates if the rewrite potentially splits dags, which is used
- * for phase ordering of rewrites.
- *
- * @return true if dag splits are possible.
- */
- @Override public boolean createsSplitDag() {
- return false;
- }
-
- /**
- * Handle an arbitrary statement block. Specific type constraints have to be ensured
- * within the individual rewrites. If a rewrite does not apply to individual blocks, it
- * should simply return the input block.
- *
- * @param sb statement block
- * @param state program rewrite status
- * @return list of statement blocks
- */
- @Override
- public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
- return Arrays.asList(sb);
- }
-
- /**
- * Handle a list of statement blocks. Specific type constraints have to be ensured
- * within the individual rewrites. If a rewrite does not require sequence access, it
- * should simply return the input list of statement blocks.
- *
- * @param sbs list of statement blocks
- * @param state program rewrite status
- * @return list of statement blocks
- */
- @Override
- public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
- return sbs;
- }
-}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 0092a3a..906ed1f 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -63,7 +63,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void simpleBinary() {
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
/*
@@ -75,7 +75,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
* TOSTRING 1 1 800 0.0625 80
* UnaryOp 1 1 8 0.0625 0.8
*/
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = computeCost + readCost;
@@ -84,9 +84,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void ifElseTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
@@ -94,9 +94,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void whileTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
runTest("WhileCostEstimatorTest.dml", false, expectedCost);
@@ -104,9 +104,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void forLoopTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
double expectedCost = (computeCost + readCost + predicateCost) * 5;
@@ -115,9 +115,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void parForLoopTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
double expectedCost = (computeCost + readCost + predicateCost) * 5;
@@ -126,9 +126,9 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void functionTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = (computeCost + readCost);
runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
@@ -136,7 +136,7 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
@Test
public void federatedMultiply() {
- fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;