You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/02/24 10:06:09 UTC
[systemds] branch main updated: [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 37b3e93 [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
37b3e93 is described below
commit 37b3e934ddc9d686d8f6ede9f689038a998ff87a
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Feb 15 11:48:33 2022 +0100
[SYSTEMDS-3018] Federated Cost Estimation for Repetitions
This commit changes the federated plan cost estimation when while/for/if statement blocks are used.
Closes #1547.
---
src/main/java/org/apache/sysds/hops/Hop.java | 22 +++
.../java/org/apache/sysds/hops/OptimizerUtils.java | 23 +--
.../org/apache/sysds/hops/cost/CostEstimator.java | 4 +-
.../org/apache/sysds/hops/cost/FederatedCost.java | 38 ++---
.../sysds/hops/cost/FederatedCostEstimator.java | 62 +++----
.../java/org/apache/sysds/hops/cost/HopRel.java | 2 +-
.../hops/ipa/IPAPassRewriteFederatedPlan.java | 7 +-
.../java/org/apache/sysds/parser/DMLProgram.java | 6 +
.../org/apache/sysds/parser/ForStatementBlock.java | 18 +-
.../sysds/parser/FunctionStatementBlock.java | 9 +
.../org/apache/sysds/parser/IfStatementBlock.java | 15 +-
.../org/apache/sysds/parser/StatementBlock.java | 33 ++++
.../apache/sysds/parser/WhileStatementBlock.java | 15 +-
.../runtime/controlprogram/WhileProgramBlock.java | 6 +-
.../fedplanning/FederatedCostEstimatorTest.java | 181 ++++++++++++++++-----
15 files changed, 326 insertions(+), 115 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 037bfa5..003492f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -93,6 +93,7 @@ public abstract class Hop implements ParseInfo {
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
protected FederatedCost _federatedCost = new FederatedCost();
+ protected double repetitions = 1;
/**
* Field defining if prefetch should be activated for operation.
@@ -996,6 +997,15 @@ public abstract class Hop implements ParseInfo {
_federatedCost = cost;
}
+ /**
+ * Reset federated cost of this hop and all children of this hop.
+ */
+ public void resetFederatedCost(){
+ _federatedCost = new FederatedCost();
+ for ( Hop input : getInput() )
+ input.resetFederatedCost();
+ }
+
public void setUpdateType(UpdateType update){
_updateType = update;
}
@@ -1539,6 +1549,18 @@ public abstract class Hop implements ParseInfo {
return ret;
}
+ public void updateRepetitionEstimates(double repetitions){
+ if ( !federatedCostInitialized() ){
+ this.repetitions = repetitions;
+ for ( Hop input : getInput() )
+ input.updateRepetitionEstimates(repetitions);
+ }
+ }
+
+ public double getRepetitions(){
+ return repetitions;
+ }
+
/**
* Clones the attributes of that and copies it over to this.
*
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 47b5822..4d48df6 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -1289,17 +1289,18 @@ public class OptimizerUtils
if( fpb.getStatementBlock()==null )
return defaultValue;
ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
- try {
- HashMap<Long,Long> memo = new HashMap<>();
- long from = rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
- long to = rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
- long increment = (fsb.getIncrementHops()==null) ? (from < to) ? 1 : -1 :
- rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
- if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && increment != Long.MAX_VALUE )
- return (int)Math.ceil(((double)(to-from+1))/increment);
- }
- catch(Exception ex){}
- return defaultValue;
+ return getNumIterations(fsb, defaultValue);
+ }
+
+ public static long getNumIterations(ForStatementBlock fsb, long defaultValue){
+ HashMap<Long,Long> memo = new HashMap<>();
+ long from = rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
+ long to = rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
+ long increment = (fsb.getIncrementHops()==null) ? (from < to) ? 1 : -1 :
+ rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
+ if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && increment != Long.MAX_VALUE )
+ return (int)Math.ceil(((double)(to-from+1))/increment);
+ else return defaultValue;
}
public static long getNumIterations(ForProgramBlock fpb, LocalVariableMap vars, long defaultValue) {
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index 03948d4..497b807 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -116,7 +116,7 @@ public abstract class CostEstimator
for( ProgramBlock pb2 : tmp.getChildBlocks() )
ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
- ret *= getNumIterations(stats, tmp);
+ ret *= getNumIterations(tmp);
}
else if ( pb instanceof FunctionProgramBlock ) {
FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
@@ -413,7 +413,7 @@ public abstract class CostEstimator
vs[2] = _unknownStats;
}
- private static long getNumIterations(HashMap<String,VarStats> stats, ForProgramBlock pb) {
+ private static long getNumIterations(ForProgramBlock pb) {
return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER);
}
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
index f4f8db4..8831fdc 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -30,15 +30,20 @@ public class FederatedCost {
protected double _outputTransferCost = 0;
protected double _inputTotalCost = 0;
+ protected double _repetitions = 1;
+ protected double _totalCost;
+
public FederatedCost(){}
public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
- double computeCost, double inputTotalCost){
+ double computeCost, double inputTotalCost, double repetitions){
_readCost = readCost;
_inputTransferCost = inputTransferCost;
_outputTransferCost = outputTransferCost;
_computeCost = computeCost;
_inputTotalCost = inputTotalCost;
+ _repetitions = repetitions;
+ _totalCost = calcTotal();
}
/**
@@ -46,15 +51,15 @@ public class FederatedCost {
* @return total cost
*/
public double getTotal(){
- return _computeCost + _readCost + _inputTransferCost + _outputTransferCost + _inputTotalCost;
+ return _totalCost;
}
- /**
- * Multiply the input costs by the number of times the costs are repeated.
- * @param repetitionNumber number of repetitions of the costs
- */
- public void addRepetitionCost(int repetitionNumber){
- _inputTotalCost *= repetitionNumber;
+ private double calcTotal(){
+ return (_computeCost + _readCost + _inputTransferCost + _outputTransferCost) * _repetitions + _inputTotalCost;
+ }
+
+ private void updateTotal(){
+ this._totalCost = calcTotal();
}
/**
@@ -75,6 +80,7 @@ public class FederatedCost {
*/
public void addInputTotalCost(double additionalCost){
_inputTotalCost += additionalCost;
+ updateTotal();
}
/**
@@ -82,19 +88,7 @@ public class FederatedCost {
* @param federatedCost input cost from which the total is retrieved
*/
public void addInputTotalCost(FederatedCost federatedCost){
- _inputTotalCost += federatedCost.getTotal();
- }
-
- /**
- * Add costs of FederatedCost object to this object's current costs.
- * @param additionalCost object to add to this object
- */
- public void addFederatedCost(FederatedCost additionalCost){
- _readCost += additionalCost._readCost;
- _inputTransferCost += additionalCost._inputTransferCost;
- _outputTransferCost += additionalCost._outputTransferCost;
- _computeCost += additionalCost._computeCost;
- _inputTotalCost += additionalCost._inputTotalCost;
+ addInputTotalCost(federatedCost.getTotal());
}
@Override
@@ -110,6 +104,8 @@ public class FederatedCost {
builder.append(_outputTransferCost);
builder.append("\n inputTotalCost: ");
builder.append(_inputTotalCost);
+ builder.append("\n repetitions: ");
+ builder.append(_repetitions);
builder.append("\n total cost: ");
builder.append(getTotal());
return builder.toString();
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 96a33d4..400caa9 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -19,6 +19,8 @@
package org.apache.sysds.hops.cost;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ipa.MemoTable;
import org.apache.sysds.parser.DMLProgram;
@@ -39,14 +41,13 @@ import java.util.ArrayList;
* Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
*/
public class FederatedCostEstimator {
- public int DEFAULT_MEMORY_ESTIMATE = 8;
- public int DEFAULT_ITERATION_NUMBER = 15;
- public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
- public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
- public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
- public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
+ private static final Log LOG = LogFactory.getLog(FederatedCostEstimator.class.getName());
- public boolean printCosts = false; //Temporary for debugging purposes
+ public static int DEFAULT_MEMORY_ESTIMATE = 8;
+ public static double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
+ public static double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+ public static double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
+ public static double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
/**
* Estimate cost of given DML program in bytes.
@@ -54,6 +55,7 @@ public class FederatedCostEstimator {
* @return federated cost object with cost estimate in bytes
*/
public FederatedCost costEstimate(DMLProgram dmlProgram){
+ dmlProgram.updateRepetitionEstimates();
FederatedCost programTotalCost = new FederatedCost();
for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() )
programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
@@ -74,12 +76,9 @@ public class FederatedCostEstimator {
for ( StatementBlock bodyBlock : whileStatement.getBody() )
whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
}
- whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
return whileSBCost;
}
else if ( sb instanceof IfStatementBlock){
- //Get cost of if-block + else-block and divide by two
- // since only one of the code blocks will be executed in the end
IfStatementBlock ifSB = (IfStatementBlock) sb;
FederatedCost ifSBCost = new FederatedCost();
for ( Statement statement : ifSB.getStatements() ){
@@ -89,7 +88,6 @@ public class FederatedCostEstimator {
for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
}
- ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
return ifSBCost;
}
@@ -106,7 +104,6 @@ public class FederatedCostEstimator {
for ( StatementBlock forStatementBlockBody : forStatement.getBody() )
forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
}
- forSBCost.addRepetitionCost(forSB.getEstimateReps());
return forSBCost;
}
else if ( sb instanceof FunctionStatementBlock){
@@ -182,12 +179,13 @@ public class FederatedCostEstimator {
root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+ double rootRepetitions = root.getRepetitions();
FederatedCost rootFedCost =
- new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+ new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts, rootRepetitions);
root.setFederatedCost(rootFedCost);
- if ( printCosts )
- printCosts(root);
+ if ( LOG.isDebugEnabled() )
+ LOG.debug(getCostInfo(root));
return rootFedCost;
}
@@ -199,7 +197,7 @@ public class FederatedCostEstimator {
* @param hopRelMemo memo table of HopRels for calculating input costs
* @return cost estimation of Hop DAG starting from given root HopRel
*/
- public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
+ public static FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
// Check if root is in memo table.
if ( hopRelMemo.containsHopRel(root) ){
return root.getCostObject();
@@ -234,7 +232,8 @@ public class FederatedCostEstimator {
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
double readCost = root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
- return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+ double rootRepetitions = root.hopRef.getRepetitions();
+ return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts, rootRepetitions);
}
}
@@ -247,7 +246,7 @@ public class FederatedCostEstimator {
* @param root hopRel for which cost is estimated
* @return input transfer cost estimate
*/
- private double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
+ private static double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root){
if ( hasFederatedInput )
return root.inputDependency.stream()
.filter(input -> (root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : input.hasLocalOutput() )
@@ -275,18 +274,21 @@ public class FederatedCostEstimator {
}
/**
- * Prints costs and information about root for debugging purposes
- * @param root hop for which information is printed
+ * Return costs and information about root for debugging purposes.
+ * @param root hop for which information is returned
+ * @return information about root cost
*/
- private static void printCosts(Hop root){
- System.out.println("===============================");
- System.out.println(root);
- System.out.println("Is federated: " + root.isFederated());
- System.out.println("Has federated output: " + root.hasFederatedOutput());
- System.out.println(root.getText());
- System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
- System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
- System.out.println(root.getFederatedCost().toString());
- System.out.println("===============================");
+ private static String getCostInfo(Hop root){
+ String sep = System.getProperty("line.separator");
+ StringBuilder costInfo = new StringBuilder();
+ costInfo
+ .append(root).append(sep)
+ .append("Is federated: ").append(root.isFederated())
+ .append(" Has federated output: ").append(root.hasFederatedOutput())
+ .append(root.getText()).append(sep)
+ .append("Pure computeCost: " + ComputeCost.getHOPComputeCost(root))
+ .append(" Dim1: " + root.getDim1() + " Dim2: " + root.getDim2()).append(sep)
+ .append(root.getFederatedCost().toString()).append(sep);
+ return costInfo.toString();
}
}
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index b1cc6dd..bd5ee85 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -55,7 +55,7 @@ public class HopRel {
hopRef = associatedHop;
this.fedOut = fedOut;
setInputDependency(hopRelMemo);
- cost = new FederatedCostEstimator().costEstimate(this, hopRelMemo);
+ cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
}
/**
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index db313af..383be42 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -69,6 +69,10 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
*/
private final static List<Hop> terminalHops = new ArrayList<>();
+ public List<Hop> getTerminalHops(){
+ return terminalHops;
+ }
+
/**
* Indicates if an IPA pass is applicable for the current configuration.
* The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
@@ -93,6 +97,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes) {
+ prog.updateRepetitionEstimates();
rewriteStatementBlocks(prog, prog.getStatementBlocks());
setFinalFedouts();
return false;
@@ -178,7 +183,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
}
private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
- if(sb.getHops() != null && !sb.getHops().isEmpty()) {
+ if(sb.hasHops()) {
for(Hop sbHop : sb.getHops()) {
if(sbHop instanceof FunctionOp) {
String funcName = ((FunctionOp) sbHop).getFunctionName();
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index 498a59d..2edffa7 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -201,6 +201,12 @@ public class DMLProgram
throw new RuntimeException(ex);
}
}
+
+ public void updateRepetitionEstimates(){
+ for ( StatementBlock stmBlock : getStatementBlocks() ){
+ stmBlock.updateRepetitionEstimates(1);
+ }
+ }
@Override
public String toString(){
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 1acd1ac..b21b9b5 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -447,6 +447,20 @@ public class ForStatementBlock extends StatementBlock
}
}
- return 10;
+ return (int) DEFAULT_LOOP_REPETITIONS;
}
-}
\ No newline at end of file
+
+ @Override
+ public void updateRepetitionEstimates(double repetitions){
+ this.repetitions = repetitions * getEstimateReps();
+ _fromHops.updateRepetitionEstimates(this.repetitions);
+ _toHops.updateRepetitionEstimates(this.repetitions);
+ _incrementHops.updateRepetitionEstimates(this.repetitions);
+ for(Statement statement : getStatements()) {
+ List<StatementBlock> children = ((ForStatement) statement).getBody();
+ for ( StatementBlock stmBlock : children ){
+ stmBlock.updateRepetitionEstimates(this.repetitions);
+ }
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index cc7ab64..ed70c69 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -258,4 +258,13 @@ public class FunctionStatementBlock extends StatementBlock implements FunctionBl
return ProgramConverter
.createDeepCopyFunctionStatementBlock(this, new HashSet<>(), new HashSet<>());
}
+
+ @Override
+ public void updateRepetitionEstimates(double repetitions){
+ for (Statement stm : getStatements()){
+ for (StatementBlock block : ((FunctionStatement) stm).getBody()){
+ block.updateRepetitionEstimates(repetitions);
+ }
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 4762a14..bae78ca 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -502,7 +502,20 @@ public class IfStatementBlock extends StatementBlock
liveInReturn.addVariables(_liveIn);
return liveInReturn;
}
-
+
+ @Override
+ public void updateRepetitionEstimates(double repetitions){
+ this.repetitions = repetitions;
+ getPredicateHops().updateRepetitionEstimates(this.repetitions);
+ for ( Statement statement : getStatements() ){
+ IfStatement ifStatement = (IfStatement) statement;
+ double blockLevelReps = repetitions / 2;
+ for ( StatementBlock ifBodySB : ifStatement.getIfBody() )
+ ifBodySB.updateRepetitionEstimates(blockLevelReps);
+ for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
+ elseBodySB.updateRepetitionEstimates(blockLevelReps);
+ }
+ }
/////////
// materialized hops recompilation flags
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 4f8cd1b..6e9545c 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -29,6 +29,7 @@ import java.util.Map.Entry;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
@@ -64,6 +65,9 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
private boolean _splitDag = false;
private boolean _nondeterministic = false;
+ protected double repetitions = 1;
+ public final static double DEFAULT_LOOP_REPETITIONS = 10;
+
public StatementBlock() {
_ID = getNextSBID();
_name = "SB"+_ID;
@@ -1238,6 +1242,35 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
return liveInReturn;
}
+ public boolean hasHops(){
+ return getHops() != null && !getHops().isEmpty();
+ }
+
+ /**
+ * Updates the repetition estimate for this statement block
+ * and all contained hops. FunctionStatementBlocks are loaded
+ * from the function dictionary and repetitions are estimated
+ * for the contained statement blocks.
+ *
+ * This method is overridden in the subclasses of StatementBlock.
+ * @param repetitions estimated number of repetitions for this statement block
+ */
+ public void updateRepetitionEstimates(double repetitions){
+ this.repetitions = repetitions;
+ if ( hasHops() ){
+ for ( Hop root : getHops() ){
+ // Set the repetition estimate for hops recursively
+ if(root instanceof FunctionOp) {
+ String funcName = ((FunctionOp) root).getFunctionName();
+ FunctionStatementBlock sbFuncBlock = getDMLProg().getBuiltinFunctionDictionary().getFunction(funcName);
+ sbFuncBlock.updateRepetitionEstimates(repetitions);
+ }
+ else
+ root.updateRepetitionEstimates(repetitions);
+ }
+ }
+ }
+
///////////////////////////////////////////////////////////////
// validate error handling (consistent for all expressions)
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index 7a09242..b28e682 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -22,6 +22,7 @@ package org.apache.sysds.parser;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
@@ -317,6 +318,18 @@ public class WhileStatementBlock extends StatementBlock
return liveInReturn;
}
+
+ @Override
+ public void updateRepetitionEstimates(double repetitions){
+ this.repetitions = repetitions * DEFAULT_LOOP_REPETITIONS;
+ getPredicateHops().updateRepetitionEstimates(this.repetitions);
+ for(Statement statement : getStatements()) {
+ List<StatementBlock> children = ((WhileStatement)statement).getBody();
+ for ( StatementBlock stmBlock : children ){
+ stmBlock.updateRepetitionEstimates(this.repetitions);
+ }
+ }
+ }
/////////
// materialized hops recompilation flags
@@ -331,4 +344,4 @@ public class WhileStatementBlock extends StatementBlock
public boolean requiresPredicateRecompilation() {
return _requiresPredicateRecompile;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
index cc916de..4695b94 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
@@ -20,8 +20,12 @@
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ValueType;
@@ -151,4 +155,4 @@ public class WhileProgramBlock extends ProgramBlock
public String printBlockErrorLocation(){
return "ERROR: Runtime error in while program block generated from while statement block between lines " + _beginLine + " and " + _endLine + " -- ";
}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 906ed1f..b8ad989 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.privacy.fedplanning;
+import net.jcip.annotations.NotThreadSafe;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
@@ -32,15 +33,21 @@ import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.IPAPassRewriteFederatedPlan;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.parser.ParserFactory;
import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
+import org.junit.After;
import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
import org.junit.Test;
import java.io.FileNotFoundException;
@@ -51,6 +58,7 @@ import java.util.Set;
import static org.apache.sysds.common.Types.OpOp2.MULT;
+@NotThreadSafe
public class FederatedCostEstimatorTest extends AutomatedTestBase {
private static final String TEST_DIR = "functions/privacy/fedplanning/";
@@ -58,13 +66,36 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/";
FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+ private static double COMPUTE_FLOPS;
+ private static double READ_PS;
+ private static double NETWORK_PS;
+
@Override
public void setUp() {}
+ @BeforeClass
+ public static void storeConstants(){
+ COMPUTE_FLOPS = FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS;
+ READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS;
+ NETWORK_PS = FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+ }
+
+ @Before
+ public void setConstants(){
+ FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
+ FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+ }
+
+ @After
+ public void resetConstants(){
+ FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = COMPUTE_FLOPS;
+ FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS;
+ FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = NETWORK_PS;
+ }
+
@Test
public void simpleBinary() {
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
/*
* HOP Occurences ComputeCost ReadCost ComputeCostFinal ReadCostFinal
@@ -75,70 +106,87 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
* TOSTRING 1 1 800 0.0625 80
* UnaryOp 1 1 8 0.0625 0.8
*/
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = computeCost + readCost;
runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void simpleBinaryHopRelTest() {
+ runHopRelTest("BinaryCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void ifElseTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void ifElseHopRelTest(){
+ runHopRelTest("IfElseCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void whileTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
- double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = (computeCost + readCost + 0.0625 + 0.0625 + 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS;
runTest("WhileCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void whileHopRelTest(){
+ runHopRelTest("WhileCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void forLoopTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
double expectedCost = (computeCost + readCost + predicateCost) * 5;
runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void forLoopHopRelTest(){
+ runHopRelTest("ForLoopCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void parForLoopTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
double expectedCost = (computeCost + readCost + predicateCost) * 5;
runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void parForLoopHopRelTest(){
+ runHopRelTest("ParForLoopCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void functionTest(){
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
- double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
double expectedCost = (computeCost + readCost);
runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
}
@Test
+ public void functionHopRelTest(){
+ runHopRelTest("FunctionCostEstimatorTest.dml", false);
+ }
+
+ @Test
public void federatedMultiply() {
- fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
- fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
- fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
double literalOpCost = 10*0.0625;
double naryOpCostSpecial = (0.125+2.2);
@@ -224,27 +272,72 @@ public class FederatedCostEstimatorTest extends AutomatedTestBase {
hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
}
+ private DMLProgram testSetup(String scriptFilename) throws IOException{
+ setTestConfig(scriptFilename);
+ String dmlScriptString = readScript(scriptFilename);
+
+ //parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+ modifyFedouts(prog);
+ dmlt.rewriteHopsDAG(prog);
+ hops = new HashSet<>();
+ prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+ }
+ return prog;
+ }
+
+ private void compareResults(DMLProgram prog) {
+ IPAPassRewriteFederatedPlan rewriter = new IPAPassRewriteFederatedPlan();
+ rewriter.rewriteProgram(prog, new FunctionCallGraph(prog), null);
+
+ double actualCost = 0;
+ for ( Hop root : rewriter.getTerminalHops() ){
+ actualCost += root.getFederatedCost().getTotal();
+ }
+
+
+ rewriter.getTerminalHops().forEach(Hop::resetFederatedCost);
+ fedCostEstimator = new FederatedCostEstimator();
+ double expectedCost = 0;
+ for ( Hop root : rewriter.getTerminalHops() )
+ expectedCost += fedCostEstimator.costEstimate(root).getTotal();
+ Assert.assertEquals(expectedCost, actualCost, 0.0001);
+ }
+
+ private void runHopRelTest( String scriptFilename, boolean expectedException ) {
+ boolean raisedException = false;
+ try
+ {
+ DMLProgram prog = testSetup(scriptFilename);
+ compareResults(prog);
+ }
+ catch(LanguageException ex) {
+ raisedException = true;
+ if(raisedException!=expectedException)
+ ex.printStackTrace();
+ }
+ catch(Exception ex2) {
+ ex2.printStackTrace();
+ throw new RuntimeException(ex2);
+ }
+
+ Assert.assertEquals("Expected exception does not match raised exception",
+ expectedException, raisedException);
+ }
+
private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) {
boolean raisedException = false;
try
{
- setTestConfig(scriptFilename);
- String dmlScriptString = readScript(scriptFilename);
-
- //parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
- ParserWrapper parser = ParserFactory.createParser();
- DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
- DMLTranslator dmlt = new DMLTranslator(prog);
- dmlt.liveVariableAnalysis(prog);
- dmlt.validateParseTree(prog);
- dmlt.constructHops(prog);
- if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
- modifyFedouts(prog);
- dmlt.rewriteHopsDAG(prog);
- hops = new HashSet<>();
- prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
- }
+ DMLProgram prog = testSetup(scriptFilename);
+ fedCostEstimator = new FederatedCostEstimator();
FederatedCost actualCost = fedCostEstimator.costEstimate(prog);
Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001);
}