You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/10 18:50:56 UTC
[systemds] branch master updated: [SYSTEMDS-3018] Costing of
Federated Execution Plans
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 6523457 [SYSTEMDS-3018] Costing of Federated Execution Plans
6523457 is described below
commit 65234573b986e8d0ba5bd27bd3d7641b921a7641
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Jun 18 17:31:31 2021 +0200
[SYSTEMDS-3018] Costing of Federated Execution Plans
Closes #1367.
---
src/main/java/org/apache/sysds/hops/Hop.java | 93 +++++--
.../codegen/opt/PlanSelectionFuseCostBasedV2.java | 182 +-------------
.../org/apache/sysds/hops/cost/ComputeCost.java | 225 +++++++++++++++++
.../sysds/hops/cost/CostEstimationWrapper.java | 13 +-
.../org/apache/sysds/hops/cost/CostEstimator.java | 72 +++---
.../hops/cost/CostEstimatorStaticRuntime.java | 40 +--
.../org/apache/sysds/hops/cost/FederatedCost.java | 117 +++++++++
.../sysds/hops/cost/FederatedCostEstimator.java | 214 ++++++++++++++++
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 1 +
.../hops/rewrite/RewriteFederatedExecution.java | 133 ++++++++--
.../rewrite/RewriteFederatedStatementBlocks.java | 66 +++++
.../runtime/instructions/FEDInstructionParser.java | 1 +
.../fed/AggregateUnaryFEDInstruction.java | 56 ++++-
.../instructions/fed/AppendFEDInstruction.java | 5 +-
.../instructions/fed/CtableFEDInstruction.java | 4 +-
.../runtime/instructions/fed/FEDInstruction.java | 3 +
.../instructions/fed/ReorgFEDInstruction.java | 36 ++-
.../org/apache/sysds/test/AutomatedTestBase.java | 26 ++
.../fedplanning/FederatedCostEstimatorTest.java | 279 +++++++++++++++++++++
.../fedplanning/FederatedMultiplyPlanningTest.java | 33 +--
.../fedplanning/BinaryCostEstimatorTest.dml | 26 ++
.../FederatedMultiplyCostEstimatorTest.dml | 31 +++
.../fedplanning/ForLoopCostEstimatorTest.dml | 27 ++
.../fedplanning/FunctionCostEstimatorTest.dml | 28 +++
.../fedplanning/IfElseCostEstimatorTest.dml | 30 +++
.../fedplanning/ParForLoopCostEstimatorTest.dml | 27 ++
.../privacy/fedplanning/WhileCostEstimatorTest.dml | 27 ++
27 files changed, 1483 insertions(+), 312 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index ececf52..45fc3af 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.lops.CSVReBlock;
@@ -91,8 +92,9 @@ public abstract class Hop implements ParseInfo {
* If it is lout, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+ protected FederatedCost _federatedCost = new FederatedCost();
- // Estimated size for the output produced from this Hop
+ // Estimated size for the output produced from this Hop in bytes
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
// Estimated size for the entire operation represented by this Hop
@@ -535,7 +537,7 @@ public abstract class Hop implements ParseInfo {
* only use getMemEstimate(), which gives memory required to store
* all inputs and the output.
*
- * @return output size memory estimate
+ * @return output size memory estimate in bytes
*/
protected double getOutputSize() {
return _outputMemEstimate;
@@ -545,14 +547,22 @@ public abstract class Hop implements ParseInfo {
return getInputSize(null);
}
- protected double getInputSize(Collection<String> exclVars) {
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in bytes.
+ * @param exclVars name of input hops to exclude from the input estimate
+ * @param injectedDefault default memory estimate per cell (bytes), scaled by the input dimensions and used when the memory estimate of the input is negative
+ * @return input memory estimate in bytes
+ */
+ protected double getInputSize(Collection<String> exclVars, double injectedDefault){
double sum = 0;
int len = _input.size();
for( int i=0; i<len; i++ ) { //for all inputs
Hop hi = _input.get(i);
if( exclVars != null && exclVars.contains(hi.getName()) )
continue;
- double hmout = hi.getOutputMemEstimate();
+ double hmout = hi.getOutputMemEstimate(injectedDefault);
+ if (hmout < 0)
+ hmout = injectedDefault*(Math.max(hi.getDim1(),1) * Math.max(hi.getDim2(),1));
if( hmout > 1024*1024 ) {//for relevant sizes
//check if already included in estimate (if an input is used
//multiple times it is still only required once in memory)
@@ -564,10 +574,19 @@ public abstract class Hop implements ParseInfo {
}
sum += hmout;
}
-
+
return sum;
}
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in bytes.
+ * @param exclVars name of input hops to exclude from the input estimate
+ * @return input memory estimate in bytes
+ */
+ protected double getInputSize(Collection<String> exclVars) {
+ return getInputSize(exclVars, OptimizerUtils.INVALID_SIZE);
+ }
+
protected double getInputSize( int pos ){
double ret = 0;
if( _input.size()>pos )
@@ -582,12 +601,11 @@ public abstract class Hop implements ParseInfo {
/**
* NOTES:
* * Purpose: Whenever the output dimensions / sparsity of a hop are unknown, this hop
- * should store its worst-case output statistics (if known) in that table. Subsequent
- * hops can then
+ * should store its worst-case output statistics (if known) in that table.
* * Invocation: Intended to be called for ALL root nodes of one Hops DAG with the same
* (initially empty) memo table.
*
- * @return memory estimate
+ * @return memory estimate in bytes
*/
public double getMemEstimate() {
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
@@ -620,17 +638,46 @@ public abstract class Hop implements ParseInfo {
}
//wrappers for meaningful public names to memory estimates.
-
+
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in bytes.
+ * @return input memory estimate in bytes
+ */
public double getInputMemEstimate()
{
return getInputSize();
}
-
+
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in bytes.
+ * @param injectedDefault default memory estimate per cell (bytes), scaled by the input dimensions and used when the memory estimate of the input is negative
+ * @return input memory estimate in bytes
+ */
+ public double getInputMemEstimate(double injectedDefault){
+ return getInputSize(null, injectedDefault);
+ }
+
+ /**
+ * Output memory estimate in bytes.
+ * @return output memory estimate in bytes
+ */
public double getOutputMemEstimate()
{
return getOutputSize();
}
+ /**
+ * Output memory estimate in bytes, or the scaled injected default if that is larger (e.g., when the memory estimate is negative).
+ * The injected default represents the memory estimate per output cell, hence it is multiplied by the estimated
+ * dimensions of the output of the hop.
+ * @param injectedDefault memory estimate to be returned in case the memory estimate defaults to a negative number
+ * @return output memory estimate in bytes
+ */
+ public double getOutputMemEstimate(double injectedDefault)
+ {
+ return Math.max(getOutputMemEstimate(),injectedDefault*(Math.max(getDim1(),1) * Math.max(getDim2(),1)));
+ }
+
public double getIntermediateMemEstimate()
{
return getIntermediateSize();
@@ -823,17 +870,13 @@ public abstract class Hop implements ParseInfo {
* This method only has an effect if FEDERATED_COMPILATION is activated.
*/
protected void updateETFed(){
- if ( _federatedOutput == FederatedOutput.FOUT || _federatedOutput == FederatedOutput.LOUT )
+ if ( _federatedOutput.isForced() )
_etype = ExecType.FED;
}
public boolean isFederated(){
return getExecType() == ExecType.FED;
}
-
- public boolean isFederatedOutput(){
- return _federatedOutput == FederatedOutput.FOUT;
- }
public boolean someInputFederated(){
return getInput().stream().anyMatch(Hop::hasFederatedOutput);
@@ -889,6 +932,26 @@ public abstract class Hop implements ParseInfo {
return _federatedOutput == FederatedOutput.FOUT;
}
+ public boolean hasLocalOutput(){
+ return _federatedOutput == FederatedOutput.LOUT;
+ }
+
+ /**
+ * Check if federated cost has been initialized for this Hop.
+ * @return true if federated cost has been initialized
+ */
+ public boolean federatedCostInitialized(){
+ return _federatedCost.getTotal() > 0;
+ }
+
+ public FederatedCost getFederatedCost(){
+ return _federatedCost;
+ }
+
+ public void setFederatedCost(FederatedCost cost){
+ _federatedCost = cost;
+ }
+
public void setUpdateType(UpdateType update){
_updateType = update;
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 0b20876..c6cfe9f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -46,17 +46,8 @@ import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.IndexingOp;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.ParameterizedBuiltinOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
@@ -64,6 +55,7 @@ import org.apache.sysds.hops.codegen.template.TemplateRow;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
@@ -175,8 +167,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
//obtain hop compute costs per cell once
HashMap<Long, Double> computeCosts = new HashMap<>();
for( Long hopID : part.getPartition() )
- getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
-
+ computeCosts.put(hopID, ComputeCost.getHOPComputeCost(memo.getHopRefs().get(hopID)));
+
//prepare pruning helpers and prune memo table w/ determined mat points
StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts),
getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
@@ -1011,174 +1003,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
return costs;
}
- private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts)
- {
- //get costs for given hop
- double costs = 1;
- if( current instanceof UnaryOp ) {
- switch( ((UnaryOp)current).getOp() ) {
- case ABS:
- case ROUND:
- case CEIL:
- case FLOOR:
- case SIGN: costs = 1; break;
- case SPROP:
- case SQRT: costs = 2; break;
- case EXP: costs = 18; break;
- case SIGMOID: costs = 21; break;
- case LOG:
- case LOG_NZ: costs = 32; break;
- case NCOL:
- case NROW:
- case PRINT:
- case ASSERT:
- case CAST_AS_BOOLEAN:
- case CAST_AS_DOUBLE:
- case CAST_AS_INT:
- case CAST_AS_MATRIX:
- case CAST_AS_SCALAR: costs = 1; break;
- case SIN: costs = 18; break;
- case COS: costs = 22; break;
- case TAN: costs = 42; break;
- case ASIN: costs = 93; break;
- case ACOS: costs = 103; break;
- case ATAN: costs = 40; break;
- case SINH: costs = 93; break; // TODO:
- case COSH: costs = 103; break;
- case TANH: costs = 40; break;
- case CUMSUM:
- case CUMMIN:
- case CUMMAX:
- case CUMPROD: costs = 1; break;
- case CUMSUMPROD: costs = 2; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for: "+((UnaryOp)current).getOp());
- }
- }
- else if( current instanceof BinaryOp ) {
- switch( ((BinaryOp)current).getOp() ) {
- case MULT:
- case PLUS:
- case MINUS:
- case MIN:
- case MAX:
- case AND:
- case OR:
- case EQUAL:
- case NOTEQUAL:
- case LESS:
- case LESSEQUAL:
- case GREATER:
- case GREATEREQUAL:
- case CBIND:
- case RBIND: costs = 1; break;
- case INTDIV: costs = 6; break;
- case MODULUS: costs = 8; break;
- case DIV: costs = 22; break;
- case LOG:
- case LOG_NZ: costs = 32; break;
- case POW: costs = (HopRewriteUtils.isLiteralOfValue(
- current.getInput().get(1), 2) ? 1 : 16); break;
- case MINUS_NZ:
- case MINUS1_MULT: costs = 2; break;
- case MOMENT:
- int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
- HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
- switch( type ) {
- case 0: costs = 1; break; //count
- case 1: costs = 8; break; //mean
- case 2: costs = 16; break; //cm2
- case 3: costs = 31; break; //cm3
- case 4: costs = 51; break; //cm4
- case 5: costs = 16; break; //variance
- }
- break;
- case COV: costs = 23; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for: "+((BinaryOp)current).getOp());
- }
- }
- else if( current instanceof TernaryOp ) {
- switch( ((TernaryOp)current).getOp() ) {
- case IFELSE:
- case PLUS_MULT:
- case MINUS_MULT: costs = 2; break;
- case CTABLE: costs = 3; break;
- case MOMENT:
- int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
- HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
- switch( type ) {
- case 0: costs = 2; break; //count
- case 1: costs = 9; break; //mean
- case 2: costs = 17; break; //cm2
- case 3: costs = 32; break; //cm3
- case 4: costs = 52; break; //cm4
- case 5: costs = 17; break; //variance
- }
- break;
- case COV: costs = 23; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for: "+((TernaryOp)current).getOp());
- }
- }
- else if( current instanceof NaryOp ) {
- costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ?
- current.getInput().size() : 1;
- }
- else if( current instanceof ParameterizedBuiltinOp ) {
- costs = 1;
- }
- else if( current instanceof IndexingOp ) {
- costs = 1;
- }
- else if( current instanceof ReorgOp ) {
- costs = 1;
- }
- else if( current instanceof DnnOp ) {
- switch( ((DnnOp)current).getOp() ) {
- case BIASADD:
- case BIASMULT:
- costs = 2;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for: "+((DnnOp)current).getOp());
- }
- }
- else if( current instanceof AggBinaryOp ) {
- //outer product template w/ matrix-matrix
- //or row template w/ matrix-vector or matrix-matrix
- costs = 2 * current.getInput().get(0).getDim2();
- if( current.getInput().get(0).dimsKnown(true) )
- costs *= current.getInput().get(0).getSparsity();
- }
- else if( current instanceof AggUnaryOp) {
- switch(((AggUnaryOp)current).getOp()) {
- case SUM: costs = 4; break;
- case SUM_SQ: costs = 5; break;
- case MIN:
- case MAX: costs = 1; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for: "+((AggUnaryOp)current).getOp());
- }
- switch(((AggUnaryOp)current).getDirection()) {
- case Col: costs *= Math.max(current.getInput().get(0).getDim1(),1); break;
- case Row: costs *= Math.max(current.getInput().get(0).getDim2(),1); break;
- case RowCol: costs *= getSize(current.getInput().get(0)); break;
- }
- }
-
- //scale by current output size in order to correctly reflect
- //a mix of row and cell operations in the same fused operator
- //(e.g., row template with fused column vector operations)
- costs *= getSize(current);
-
- computeCosts.put(current.getHopID(), costs);
- }
-
private static boolean hasNoRefToMatPoint(long hopID,
MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
return !InterestingPoint.isMatPoint(M, hopID, me, plan);
diff --git a/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
new file mode 100644
index 0000000..3ac64b6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DnnOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+
+/**
+ * Class with methods estimating compute costs of operations.
+ */
+public class ComputeCost {
+ private static final Log LOG = LogFactory.getLog(ComputeCost.class.getName());
+
+ /**
+ * Get compute cost for given HOP based on the number of floating point operations per output cell
+ * and the total number of output cells.
+ * @param currentHop for which compute cost is returned
+ * @return compute cost of currentHop as number of floating point operations
+ */
+ public static double getHOPComputeCost(Hop currentHop){
+ double costs = 1;
+ if( currentHop instanceof UnaryOp) {
+ switch( ((UnaryOp)currentHop).getOp() ) {
+ case ABS:
+ case ROUND:
+ case CEIL:
+ case FLOOR:
+ case SIGN: costs = 1; break;
+ case SPROP:
+ case SQRT: costs = 2; break;
+ case EXP: costs = 18; break;
+ case SIGMOID: costs = 21; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ case NCOL:
+ case NROW:
+ case PRINT:
+ case ASSERT:
+ case CAST_AS_BOOLEAN:
+ case CAST_AS_DOUBLE:
+ case CAST_AS_INT:
+ case CAST_AS_MATRIX:
+ case CAST_AS_SCALAR: costs = 1; break;
+ case SIN: costs = 18; break;
+ case COS: costs = 22; break;
+ case TAN: costs = 42; break;
+ case ASIN: costs = 93; break;
+ case ACOS: costs = 103; break;
+ case ATAN: costs = 40; break;
+ case SINH: costs = 93; break; // TODO:
+ case COSH: costs = 103; break;
+ case TANH: costs = 40; break;
+ case CUMSUM:
+ case CUMMIN:
+ case CUMMAX:
+ case CUMPROD: costs = 1; break;
+ case CUMSUMPROD: costs = 2; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((UnaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof BinaryOp) {
+ switch( ((BinaryOp)currentHop).getOp() ) {
+ case MULT:
+ case PLUS:
+ case MINUS:
+ case MIN:
+ case MAX:
+ case AND:
+ case OR:
+ case EQUAL:
+ case NOTEQUAL:
+ case LESS:
+ case LESSEQUAL:
+ case GREATER:
+ case GREATEREQUAL:
+ case CBIND:
+ case RBIND: costs = 1; break;
+ case INTDIV: costs = 6; break;
+ case MODULUS: costs = 8; break;
+ case DIV: costs = 22; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ case POW: costs = (HopRewriteUtils.isLiteralOfValue(
+ currentHop.getInput().get(1), 2) ? 1 : 16); break;
+ case MINUS_NZ:
+ case MINUS1_MULT: costs = 2; break;
+ case MOMENT:
+ int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+ HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 1; break; //count
+ case 1: costs = 8; break; //mean
+ case 2: costs = 16; break; //cm2
+ case 3: costs = 31; break; //cm3
+ case 4: costs = 51; break; //cm4
+ case 5: costs = 16; break; //variance
+ }
+ break;
+ case COV: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((BinaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof TernaryOp) {
+ switch( ((TernaryOp)currentHop).getOp() ) {
+ case IFELSE:
+ case PLUS_MULT:
+ case MINUS_MULT: costs = 2; break;
+ case CTABLE: costs = 3; break;
+ case MOMENT:
+ int type = (int) (currentHop.getInput().get(1) instanceof LiteralOp ?
+ HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 2; break; //count
+ case 1: costs = 9; break; //mean
+ case 2: costs = 17; break; //cm2
+ case 3: costs = 32; break; //cm3
+ case 4: costs = 52; break; //cm4
+ case 5: costs = 17; break; //variance
+ }
+ break;
+ case COV: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((TernaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof NaryOp) {
+ costs = HopRewriteUtils.isNary(currentHop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) ?
+ currentHop.getInput().size() : 1;
+ }
+ else if( currentHop instanceof ParameterizedBuiltinOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof IndexingOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof ReorgOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof DnnOp) {
+ switch( ((DnnOp)currentHop).getOp() ) {
+ case BIASADD:
+ case BIASMULT:
+ costs = 2;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((DnnOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof AggBinaryOp) {
+ //outer product template w/ matrix-matrix
+ //or row template w/ matrix-vector or matrix-matrix
+ costs = 2 * currentHop.getInput().get(0).getDim2();
+ if( currentHop.getInput().get(0).dimsKnown(true) )
+ costs *= currentHop.getInput().get(0).getSparsity();
+ }
+ else if( currentHop instanceof AggUnaryOp) {
+ switch(((AggUnaryOp)currentHop).getOp()) {
+ case SUM: costs = 4; break;
+ case SUM_SQ: costs = 5; break;
+ case MIN:
+ case MAX: costs = 1; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((AggUnaryOp)currentHop).getOp());
+ }
+ switch(((AggUnaryOp)currentHop).getDirection()) {
+ case Col: costs *= Math.max(currentHop.getInput().get(0).getDim1(),1); break;
+ case Row: costs *= Math.max(currentHop.getInput().get(0).getDim2(),1); break;
+ case RowCol: costs *= getSize(currentHop.getInput().get(0)); break;
+ }
+ }
+
+ //scale by current output size in order to correctly reflect
+ //a mix of row and cell operations in the same fused operator
+ //(e.g., row template with fused column vector operations)
+ costs *= getSize(currentHop);
+ return costs;
+ }
+
+ /**
+ * Get number of output cells of given hop.
+ * @param hop for which the number of output cells is found
+ * @return number of output cells of given hop
+ */
+ private static long getSize(Hop hop) {
+ return Math.max(hop.getDim1(),1)
+ * Math.max(hop.getDim2(),1);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
index f8d6a2d..23fcf51 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
@@ -32,7 +32,6 @@ import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
public class CostEstimationWrapper
{
-
public enum CostType {
NUM_MRJOBS, //based on number of MR jobs, [number MR jobs]
STATIC // based on FLOPS, read/write, etc, [time in sec]
@@ -44,17 +43,13 @@ public class CostEstimationWrapper
private static CostEstimator _costEstim = null;
- static
- {
-
+ static {
//create cost estimator
- try
- {
+ try {
//TODO config parameter?
_costEstim = createCostEstimator(DEFAULT_COSTTYPE);
}
- catch(Exception ex)
- {
+ catch(Exception ex) {
LOG.error("Failed cost estimator initialization.", ex);
}
}
@@ -89,5 +84,5 @@ public class CostEstimationWrapper
default:
throw new DMLRuntimeException("Unknown cost type: "+type);
}
- }
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index bb8753c..03948d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -58,9 +58,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
-public abstract class CostEstimator
+public abstract class CostEstimator
{
-
protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());
private static final int DEFAULT_NUMITER = 15;
@@ -84,7 +83,7 @@ public abstract class CostEstimator
public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String,VarStats> stats, boolean recursive) {
//obtain stats from symboltable (e.g., during recompile)
maintainVariableStatistics(vars, stats);
-
+
//get cost estimate
return rGetTimeEstimate(pb, stats, new HashSet<String>(), recursive);
}
@@ -281,10 +280,8 @@ public abstract class CostEstimator
VarStats[] vs = new VarStats[3];
String[] attr = null;
- if( inst instanceof UnaryCPInstruction )
- {
- if( inst instanceof DataGenCPInstruction )
- {
+ if( inst instanceof UnaryCPInstruction ) {
+ if( inst instanceof DataGenCPInstruction ) {
DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
@@ -298,15 +295,13 @@ public abstract class CostEstimator
type = 1;
attr = new String[]{String.valueOf(type)};
}
- else if( inst instanceof StringInitCPInstruction )
- {
+ else if( inst instanceof StringInitCPInstruction ) {
StringInitCPInstruction rinst = (StringInitCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
vs[2] = stats.get( rinst.output.getName() );
}
- else //general unary
- {
+ else { //general unary
UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
vs[0] = stats.get( uinst.input1.getName() );
vs[1] = _unknownStats;
@@ -317,69 +312,61 @@ public abstract class CostEstimator
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
- if( inst instanceof MMTSJCPInstruction )
- {
+ if( inst instanceof MMTSJCPInstruction ) {
String type = ((MMTSJCPInstruction)inst).getMMTSJType().toString();
attr = new String[]{type};
}
- else if( inst instanceof AggregateUnaryCPInstruction )
- {
+ else if( inst instanceof AggregateUnaryCPInstruction ) {
String[] parts = InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
if( opcode.equals("cm") )
- attr = new String[]{parts[parts.length-2]};
- }
+ attr = new String[]{parts[parts.length-2]};
+ }
}
}
- else if( inst instanceof BinaryCPInstruction )
- {
+ else if( inst instanceof BinaryCPInstruction ) {
BinaryCPInstruction binst = (BinaryCPInstruction) inst;
vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
-
- if( vs[0] == null ) //scalar input,
+ if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
- if( vs[1] == null ) //scalar input,
+ if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
- }
- else if( inst instanceof AggregateTernaryCPInstruction )
- {
+ }
+ else if( inst instanceof AggregateTernaryCPInstruction ) {
AggregateTernaryCPInstruction binst = (AggregateTernaryCPInstruction) inst;
//of same dimension anyway but missing third input
- vs[0] = stats.get( binst.input1.getName() );
+ vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
- if( vs[0] == null ) //scalar input,
+ if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
- if( vs[1] == null ) //scalar input,
+ if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
- else if( inst instanceof ParameterizedBuiltinCPInstruction )
- {
+ else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
//ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
String[] parts = InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
- if( opcode.equals("groupedagg") )
- {
+ if( opcode.equals("groupedagg") ) {
HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
String fn = paramsMap.get("fn");
String order = paramsMap.get("order");
AggregateOperationTypes type = CMOperator.getAggOpType(fn, order);
attr = new String[]{String.valueOf(type.ordinal())};
}
- else if( opcode.equals("rmempty") )
- {
+ else if( opcode.equals("rmempty") ) {
HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
attr = new String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)};
}
-
+
vs[0] = stats.get( parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
vs[1] = _unknownStats; //TODO
vs[2] = stats.get( parts[parts.length-1] );
@@ -389,16 +376,14 @@ public abstract class CostEstimator
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
- else if( inst instanceof MultiReturnBuiltinCPInstruction )
- {
+ else if( inst instanceof MultiReturnBuiltinCPInstruction ) {
//applies to qr, lu, eigen (cost computation on input1)
MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction) inst;
vs[0] = stats.get( minst.input1.getName() );
vs[1] = stats.get( minst.getOutput(0).getName() );
vs[2] = stats.get( minst.getOutput(1).getName() );
}
- else if( inst instanceof VariableCPInstruction )
- {
+ else if( inst instanceof VariableCPInstruction ) {
setUnknownStats(vs);
VariableCPInstruction varinst = (VariableCPInstruction) inst;
@@ -407,11 +392,10 @@ public abstract class CostEstimator
if( stats.containsKey( varinst.getInput1().getName() ) )
vs[0] = stats.get( varinst.getInput1().getName() );
attr = new String[]{varinst.getInput3().getName()};
- }
+ }
}
- else
- {
- setUnknownStats(vs);
+ else {
+ setUnknownStats(vs);
}
//maintain var status (CP output always inmem)
@@ -426,7 +410,7 @@ public abstract class CostEstimator
private static void setUnknownStats(VarStats[] vs) {
vs[0] = _unknownStats;
vs[1] = _unknownStats;
- vs[2] = _unknownStats;
+ vs[2] = _unknownStats;
}
private static long getNumIterations(HashMap<String,VarStats> stats, ForProgramBlock pb) {
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
index e2a4a75..3f29132 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
@@ -344,30 +344,30 @@ public class CostEstimatorStaticRuntime extends CostEstimator
}
return (leftSparse) ? xcm * (d1m * d1s + 1) : xcm * d1m;
}
- else if( optype.equals("uatrace") || optype.equals("uaktrace") )
- return 2 * d1m * d1n;
- else if( optype.equals("ua+") || optype.equals("uar+") || optype.equals("uac+") ){
- //sparse safe operations
- if( !leftSparse ) //dense
- return d1m * d1n;
- else //sparse
- return d1m * d1n * d1s;
- }
- else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
- return 4 * d1m * d1n; //1*k+
- else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
+ else if( optype.equals("uatrace") || optype.equals("uaktrace") )
+ return 2 * d1m * d1n;
+ else if( optype.equals("ua+") || optype.equals("uar+") || optype.equals("uac+") ){
+ //sparse safe operations
+ if( !leftSparse ) //dense
+ return d1m * d1n;
+ else //sparse
+ return d1m * d1n * d1s;
+ }
+ else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
+ return 4 * d1m * d1n; //1*k+
+ else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
return 5 * d1m * d1n; // +1 for multiplication to square term
- else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
+ else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
return 7 * d1m * d1n; //1*k+
- else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
+ else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
return 14 * d1m * d1n;
- else if( optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
- || optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
- || optype.equals("uarimax") || optype.equals("ua*") )
- return d1m * d1n;
+ else if( optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
+ || optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
+ || optype.equals("uarimax") || optype.equals("ua*") )
+ return d1m * d1n;
- return 0;
-
+ return 0;
+
case Binary: //opcodes: +, -, *, /, ^ (incl. ^2, *2),
//max, min, solve, ==, !=, <, >, <=, >=
//note: all relational ops are not sparsesafe
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
new file mode 100644
index 0000000..f4f8db4
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+/**
+ * Class storing execution cost estimates for federated executions with cost estimates split into different categories
+ * such as compute, read, and transfer cost.
+ */
+public class FederatedCost {
+ protected double _computeCost = 0;
+ protected double _readCost = 0;
+ protected double _inputTransferCost = 0;
+ protected double _outputTransferCost = 0;
+ protected double _inputTotalCost = 0;
+
+ public FederatedCost(){}
+
+ public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
+ double computeCost, double inputTotalCost){
+ _readCost = readCost;
+ _inputTransferCost = inputTransferCost;
+ _outputTransferCost = outputTransferCost;
+ _computeCost = computeCost;
+ _inputTotalCost = inputTotalCost;
+ }
+
+ /**
+ * Get the total sum of costs stored in this object.
+ * @return total cost
+ */
+ public double getTotal(){
+ return _computeCost + _readCost + _inputTransferCost + _outputTransferCost + _inputTotalCost;
+ }
+
+ /**
+ * Multiply the input costs by the number of times the costs are repeated.
+ * @param repetitionNumber number of repetitions of the costs
+ */
+ public void addRepetitionCost(int repetitionNumber){
+ _inputTotalCost *= repetitionNumber;
+ }
+
+ /**
+ * Get summed input costs.
+ * @return summed input costs
+ */
+ public double getInputTotalCost(){
+ return _inputTotalCost;
+ }
+
+ public void setInputTotalCost(double inputTotalCost){
+ _inputTotalCost = inputTotalCost;
+ }
+
+ /**
+ * Add cost to the stored input cost.
+ * @param additionalCost to add to total input cost
+ */
+ public void addInputTotalCost(double additionalCost){
+ _inputTotalCost += additionalCost;
+ }
+
+ /**
+ * Add total of federatedCost to stored inputTotalCost.
+ * @param federatedCost input cost from which the total is retrieved
+ */
+ public void addInputTotalCost(FederatedCost federatedCost){
+ _inputTotalCost += federatedCost.getTotal();
+ }
+
+ /**
+ * Add costs of FederatedCost object to this object's current costs.
+ * @param additionalCost object to add to this object
+ */
+ public void addFederatedCost(FederatedCost additionalCost){
+ _readCost += additionalCost._readCost;
+ _inputTransferCost += additionalCost._inputTransferCost;
+ _outputTransferCost += additionalCost._outputTransferCost;
+ _computeCost += additionalCost._computeCost;
+ _inputTotalCost += additionalCost._inputTotalCost;
+ }
+
+ @Override
+ public String toString(){
+ StringBuilder builder = new StringBuilder();
+ builder.append(" computeCost: ");
+ builder.append(_computeCost);
+ builder.append("\n readCost: ");
+ builder.append(_readCost);
+ builder.append("\n inputTransferCost: ");
+ builder.append(_inputTransferCost);
+ builder.append("\n outputTransferCost: ");
+ builder.append(_outputTransferCost);
+ builder.append("\n inputTotalCost: ");
+ builder.append(_inputTotalCost);
+ builder.append("\n total cost: ");
+ builder.append(getTotal());
+ return builder.toString();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
new file mode 100644
index 0000000..3e2f994
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+
+/**
+ * Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
+ */
+public class FederatedCostEstimator {
+ public int DEFAULT_MEMORY_ESTIMATE = 8;
+ public int DEFAULT_ITERATION_NUMBER = 15;
+ public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
+ public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+ public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
+ public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
+
+ public boolean printCosts = false; //Temporary for debugging purposes
+
+ /**
+ * Estimate cost of given DML program in bytes.
+ * @param dmlProgram for which the cost is estimated
+ * @return federated cost object with cost estimate in bytes
+ */
+ public FederatedCost costEstimate(DMLProgram dmlProgram){
+ FederatedCost programTotalCost = new FederatedCost();
+ for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() )
+ programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
+ return programTotalCost;
+ }
+
+ /**
+ * Cost estimate in bytes of given statement block.
+ * @param sb statement block
+ * @return federated cost object with cost estimate in bytes
+ */
+ private FederatedCost costEstimate(StatementBlock sb){
+ if ( sb instanceof WhileStatementBlock){
+ WhileStatementBlock whileSB = (WhileStatementBlock) sb;
+ FederatedCost whileSBCost = costEstimate(whileSB.getPredicateHops());
+ for ( Statement statement : whileSB.getStatements() ){
+ WhileStatement whileStatement = (WhileStatement) statement;
+ for ( StatementBlock bodyBlock : whileStatement.getBody() )
+ whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
+ }
+ whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
+ return whileSBCost;
+ }
+ else if ( sb instanceof IfStatementBlock){
+ //Get cost of if-block + else-block and divide by two
+ // since only one of the code blocks will be executed in the end
+ IfStatementBlock ifSB = (IfStatementBlock) sb;
+ FederatedCost ifSBCost = new FederatedCost();
+ for ( Statement statement : ifSB.getStatements() ){
+ IfStatement ifStatement = (IfStatement) statement;
+ for ( StatementBlock ifBodySB : ifStatement.getIfBody() )
+ ifSBCost.addInputTotalCost(costEstimate(ifBodySB));
+ for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
+ ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
+ }
+ ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
+ ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
+ return ifSBCost;
+ }
+ else if ( sb instanceof ForStatementBlock){
+ // This also includes ParForStatementBlocks
+ ForStatementBlock forSB = (ForStatementBlock) sb;
+ ArrayList<Hop> predicateHops = new ArrayList<>();
+ predicateHops.add(forSB.getFromHops());
+ predicateHops.add(forSB.getToHops());
+ predicateHops.add(forSB.getIncrementHops());
+ FederatedCost forSBCost = costEstimate(predicateHops);
+ for ( Statement statement : forSB.getStatements() ){
+ ForStatement forStatement = (ForStatement) statement;
+ for ( StatementBlock forStatementBlockBody : forStatement.getBody() )
+ forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
+ }
+ forSBCost.addRepetitionCost(forSB.getEstimateReps());
+ return forSBCost;
+ }
+ else if ( sb instanceof FunctionStatementBlock){
+ FederatedCost funcCost = addInitialInputCost(sb);
+ FunctionStatementBlock funcSB = (FunctionStatementBlock) sb;
+ for(Statement statement : funcSB.getStatements()) {
+ FunctionStatement funcStatement = (FunctionStatement) statement;
+ for ( StatementBlock funcStatementBody : funcStatement.getBody() )
+ funcCost.addInputTotalCost(costEstimate(funcStatementBody));
+ }
+ return funcCost;
+ }
+ else {
+ // StatementBlock type (no subclass)
+ return costEstimate(sb.getHops());
+ }
+ }
+
+ /**
+ * Creates new FederatedCost object and adds all child statement block cost estimates to the object.
+ * @param sb statement block
+ * @return new FederatedCost estimate object with all estimates of child statement blocks added
+ */
+ private FederatedCost addInitialInputCost(StatementBlock sb){
+ FederatedCost basicCost = new FederatedCost();
+ for ( StatementBlock childSB : sb.getDMLProg().getStatementBlocks() )
+ basicCost.addInputTotalCost(costEstimate(childSB).getTotal());
+ return basicCost;
+ }
+
+ /**
+ * Cost estimate in bytes of given list of roots.
+ * The individual cost estimates of the hops are summed.
+ * @param roots list of hops
+ * @return new FederatedCost object with sum of cost estimates of given hops
+ */
+ private FederatedCost costEstimate(ArrayList<Hop> roots){
+ FederatedCost basicCost = new FederatedCost();
+ for ( Hop root : roots )
+ basicCost.addInputTotalCost(costEstimate(root));
+ return basicCost;
+ }
+
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root.
+ * @param root of Hop DAG for which cost is estimated
+ * @return cost estimation of Hop DAG starting from given root
+ */
+ private FederatedCost costEstimate(Hop root){
+ if ( root.federatedCostInitialized() )
+ return root.getFederatedCost();
+ else {
+ // If no input has FOUT, the root will be processed by the coordinator
+ boolean hasFederatedInput = root.someInputFederated();
+ //the input cost is included the first time the input hop is used
+ //for additional usage, the additional cost is zero (disregarding potential read cost)
+ double inputCosts = root.getInput().stream()
+ .mapToDouble( in -> in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
+ .sum();
+ double inputTransferCost = hasFederatedInput ? root.getInput().stream()
+ .filter(Hop::hasLocalOutput)
+ .mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .map(inMem -> inMem/ WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+ .sum() : 0;
+ double computingCost = ComputeCost.getHOPComputeCost(root);
+ if ( hasFederatedInput ){
+ //Find the number of inputs that have FOUT set.
+ int numWorkers = (int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
+ //divide the compute cost by the number of workers the computation would be split to multiplied by
+ //the number of parallel processes at each worker multiplied by the FLOPS of each process
+ //This assumes uniform workload among the workers with FOUT data involved in the operation
+ //and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+ computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ } else computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ //Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+ double outputTransferCost = ( root.hasLocalOutput() && hasFederatedInput ) ?
+ root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+ double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+
+ FederatedCost rootFedCost =
+ new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+ root.setFederatedCost(rootFedCost);
+
+ if ( printCosts )
+ printCosts(root);
+
+ return rootFedCost;
+ }
+ }
+
+ /**
+ * Prints costs and information about root for debugging purposes
+ * @param root hop for which information is printed
+ */
+ private static void printCosts(Hop root){
+ System.out.println("===============================");
+ System.out.println(root);
+ System.out.println("Is federated: " + root.isFederated());
+ System.out.println("Has federated output: " + root.hasFederatedOutput());
+ System.out.println(root.getText());
+ System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
+ System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
+ System.out.println(root.getFederatedCost().toString());
+ System.out.println("===============================");
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 2e3edb0..04cdf32 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,6 +140,7 @@ public class ProgramRewriter
}
if ( OptimizerUtils.FEDERATED_COMPILATION ) {
_dagRuleSet.add( new RewriteFederatedExecution() );
+ _sbRuleSet.add( new RewriteFederatedStatementBlocks() );
}
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 29cda4a..e6a92ce 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,9 +23,14 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -36,6 +41,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -52,6 +58,9 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
+import java.util.Collections;
+import java.util.EnumMap;
+import java.util.Map;
import java.util.concurrent.Future;
public class RewriteFederatedExecution extends HopRewriteRule {
@@ -61,18 +70,115 @@ public class RewriteFederatedExecution extends HopRewriteRule {
return null;
for ( Hop root : roots )
visitHop(root);
+
+ return selectFederatedExecutionPlan(roots);
+ }
+
+ @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+ return null;
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting from given roots.
+ * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields.
+ * @return the list of roots with updated FederatedOutput fields.
+ */
+ private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
+ for (Hop root : roots){
+ root.resetVisitStatus();
+ }
+ for ( Hop root : roots ){
+ visitFedPlanHop(root);
+ }
return roots;
}
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if( root == null )
- return null;
- visitHop(root);
- return root;
+ /**
+ * Go through the Hop DAG and set the FederatedOutput field for each Hop from leaf to given currentHop.
+ * @param currentHop the Hop from which the DAG is visited
+ */
+ private static void visitFedPlanHop(Hop currentHop){
+ if ( currentHop.isVisited() )
+ return;
+ if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
+ // Depth first to get to the input
+ for ( Hop input : currentHop.getInput() )
+ visitFedPlanHop(input);
+ } else if ( isFederatedDataOp(currentHop) ) {
+ // leaf federated node
+ //TODO: This will block the cases where the federated DataOp is based on inputs that are also federated.
+ // This means that the actual federated leaf nodes will never be reached.
+ currentHop.setFederatedOutput(FederatedOutput.FOUT);
+ }
+ if ( ( isFedInstSupportedHop(currentHop) ) ){
+ // The Hop can be FOUT or LOUT or None. Check utility of FOUT vs LOUT vs None.
+ currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
+ }
+ else
+ currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
+ currentHop.setVisited();
+ }
+
+ /**
+ * Returns the FederatedOutput with the highest utility out of the valid FederatedOutput values.
+ * @param hop for which the utility is found
+ * @return the FederatedOutput value with highest utility for the given Hop
+ */
+ private static FederatedOutput getHighestUtilFedOut(Hop hop){
+ Map<FederatedOutput,Long> fedOutUtilMap = new EnumMap<>(FederatedOutput.class);
+ if ( isFOUTSupported(hop) )
+ fedOutUtilMap.put(FederatedOutput.FOUT, getUtilFout());
+ if ( hop.getPrivacy() == null || (hop.getPrivacy() != null && !hop.getPrivacy().hasConstraints()) )
+ fedOutUtilMap.put(FederatedOutput.LOUT, getUtilLout(hop));
+ fedOutUtilMap.put(FederatedOutput.NONE, 0L);
+
+ Map.Entry<FederatedOutput, Long> fedOutMax = Collections.max(fedOutUtilMap.entrySet(), Map.Entry.comparingByValue());
+ return fedOutMax.getKey();
}
-
- private void visitHop(Hop hop){
+
+ /**
+ * Utility if hop is FOUT. This is a simple version where it always returns 1.
+ * @return utility if hop is FOUT
+ */
+ private static long getUtilFout(){
+ //TODO: Make better utility estimation
+ return 1;
+ }
+
+ /**
+ * Utility if hop is LOUT. This is a simple version only based on dimensions.
+ * @param hop for which utility is calculated
+ * @return utility if hop is LOUT
+ */
+ private static long getUtilLout(Hop hop){
+ //TODO: Make better utility estimation
+ return -(long)hop.getMemEstimate();
+ }
+
+ private static boolean isFedInstSupportedHop(Hop hop){
+
+ // Check that some input is FOUT, otherwise none of the fed instructions will run unless it is fedinit
+ if ( (!isFederatedDataOp(hop)) && hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
+ return false;
+
+ // The following operations are supported given that the above conditions have not returned already
+ return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+ || hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp );
+ }
+
+ /**
+ * Checks to see if the associatedHop supports FOUT.
+ * @param associatedHop for which FOUT support is checked
+ * @return true if FOUT is supported by the associatedHop
+ */
+ private static boolean isFOUTSupported(Hop associatedHop){
+ // If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
+ if ( associatedHop instanceof AggUnaryOp )
+ return !associatedHop.isScalar();
+ return true;
+ }
+
+ private static void visitHop(Hop hop){
if (hop.isVisited())
return;
@@ -84,15 +190,6 @@ public class RewriteFederatedExecution extends HopRewriteRule {
hop.setVisited();
}
- private static void privacyBasedHopDecision(Hop hop){
- PrivacyPropagator.hopPropagation(hop);
- PrivacyConstraint privacyConstraint = hop.getPrivacy();
- if ( privacyConstraint != null && privacyConstraint.hasConstraints() )
- hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
- else if ( hop.someInputFederated() )
- hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
- }
-
/**
* Get privacy constraints of DataOps from federated worker,
* propagate privacy constraints from input to current hop,
@@ -101,7 +198,7 @@ public class RewriteFederatedExecution extends HopRewriteRule {
*/
private static void privacyBasedHopDecisionWithFedCall(Hop hop){
loadFederatedPrivacyConstraints(hop);
- privacyBasedHopDecision(hop);
+ PrivacyPropagator.hopPropagation(hop);
}
/**
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
new file mode 100644
index 0000000..18b36d5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
+
+ /**
+ * Indicates if the rewrite potentially splits dags, which is used
+ * for phase ordering of rewrites.
+ *
+ * @return true if dag splits are possible.
+ */
+ @Override public boolean createsSplitDag() {
+ return false;
+ }
+
+ /**
+ * Handle an arbitrary statement block. Specific type constraints have to be ensured
+ * within the individual rewrites. If a rewrite does not apply to individual blocks, it
+ * should simply return the input block.
+ *
+ * @param sb statement block
+ * @param state program rewrite status
+ * @return list of statement blocks
+ */
+ @Override
+ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+ return Arrays.asList(sb);
+ }
+
+ /**
+ * Handle a list of statement blocks. Specific type constraints have to be ensured
+ * within the individual rewrites. If a rewrite does not require sequence access, it
+ * should simply return the input list of statement blocks.
+ *
+ * @param sbs list of statement blocks
+ * @param state program rewrite status
+ * @return list of statement blocks
+ */
+ @Override
+ public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+ return sbs;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index bae38a2..755287a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -64,6 +64,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+ String2FEDInstructionType.put( "rev" , FEDType.Reorg );
// Ternary Instruction Opcodes
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index fb0647e..2e5366e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -124,9 +125,60 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
map.execute(getTID(), fr1);
- // derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
- out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+ deriveNewOutputFedMapping(in, out, fr1);
+ }
+
+ /**
+ * Set output fed mapping based on federated partitioning and aggregation type.
+ * @param in matrix object from which the fed partitioning originates
+ * @param out matrix object holding the dimensions of the instruction output
+ * @param fr1 federated request holding the instruction execution call
+ */
+ private void deriveNewOutputFedMapping(MatrixObject in, MatrixObject out, FederatedRequest fr1){
+ //Get agg type
+ if ( !(instOpcode.equals("uack+") || instOpcode.equals("uark+")) )
+ throw new DMLRuntimeException("Operation " + instOpcode + " is unknown to FOUT processing");
+ boolean isColAgg = instOpcode.equals("uack+");
+ //Get partition type
+ FederationMap.FType inFtype = in.getFedMapping().getType();
+ //Get fedmap from in
+ FederationMap inputFedMapCopy = in.getFedMapping().copyWithNewID(fr1.getID());
+
+ //if partition type is row and aggregation type is row
+ // then get row dim split from input and use as row dimension and get col dimension from output col dimension
+ // and set FType to ROW
+ if ( inFtype.isRowPartitioned() && !isColAgg ){
+ for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+ range.setEndDim(1,out.getNumColumns());
+ inputFedMapCopy.setType(FederationMap.FType.ROW);
+ }
+ //if partition type is row and aggregation type is col
+ // then get row and col dimension from out and use those dimensions for both federated workers
+ // and set FType to PART
+ //if partition type is col and aggregation type is row
+ // then set row and col dimension from out and use those dimensions for both federated workers
+ // and set FType to PART
+ if ( (inFtype.isRowPartitioned() && isColAgg) || (inFtype.isColPartitioned() && !isColAgg) ){
+ for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() ){
+ range.setBeginDim(0,0);
+ range.setBeginDim(1,0);
+ range.setEndDim(0,out.getNumRows());
+ range.setEndDim(1,out.getNumColumns());
+ }
+ inputFedMapCopy.setType(FederationMap.FType.PART);
+ }
+ //if partition type is col and aggregation type is col
+ // then set row dimension to output and col dimension to in col split
+ // and set FType to COL
+ if ( inFtype.isColPartitioned() && isColAgg ){
+ for ( FederatedRange range : inputFedMapCopy.getFederatedRanges() )
+ range.setEndDim(0,out.getNumRows());
+ inputFedMapCopy.setType(FederationMap.FType.COL);
+ }
+
+ //set out fedmap in the end
+ out.setFedMapping(inputFedMapCopy);
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 825b984..c2e7ab1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -130,8 +130,9 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
}
else {
throw new DMLRuntimeException("Unsupported federated append: "
- + (mo1.isFederated() ? mo1.getFedMapping().getType().name():"LOCAL") + " "
- + (mo2.isFederated() ? mo2.getFedMapping().getType().name():"LOCAL") + " " + _cbind);
+ + " input 1 FType is " + (mo1.isFederated() ? mo1.getFedMapping().getType().name():"LOCAL")
+ + ", input 2 FType is " + (mo2.isFederated() ? mo2.getFedMapping().getType().name():"LOCAL")
+ + ", and column bind is " + _cbind);
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2a1cbb1..8795308 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -217,10 +217,10 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
* @param mo2 input matrix object mo2
* @return boolean indicating if the output can be kept on the federated sites
*/
- private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+ private static boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
MatrixBlock mb = mo2.acquireReadAndRelease();
FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
- SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+ SortedMap<Double, Double> fedDims = new TreeMap<>(); // <beginDim, endDim>
// collect min and max of the corresponding slices of mo2
IntStream.range(0, fedRanges.length).forEach(i -> {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f35030f..0e00faa 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -60,6 +60,9 @@ public abstract class FEDInstruction extends Instruction {
public boolean isForcedLocal() {
return this == LOUT;
}
+ public boolean isForced(){
+ return this == FOUT || this == LOUT;
+ }
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 19847a2..cb3074f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
@@ -52,7 +53,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
-
+ private static boolean fedoutFlagInString = false;
+
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
}
@@ -80,6 +82,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
}
else if ( opcode.equalsIgnoreCase("rev") ) {
+ fedoutFlagInString = parts.length > 3;
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
@@ -96,24 +99,34 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
+ "Federated input expected, but invoked w/ "+mo1.isFederated());
+ if ( !( mo1.isFederated(FederationMap.FType.COL) || mo1.isFederated(FederationMap.FType.ROW)) )
+ throw new DMLRuntimeException("Federation type " + mo1.getFedMapping().getType()
+ + " is not supported for Reorg processing");
if(instOpcode.equals("r'")) {
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1},
new long[] {mo1.getFedMapping().getID()}, true);
- mo1.getFedMapping().execute(getTID(), true, fr1);
+ if (_fedOut != null && !_fedOut.isForcedLocal()){
+ mo1.getFedMapping().execute(getTID(), true, fr1);
- //drive output federated mapping
- MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
- out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+ //derive output federated mapping
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+ } else {
+ FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+ Future<FederatedResponse>[] execResponse = mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
+ ec.setMatrixOutput(output.getName(),
+ FederationUtils.bind(execResponse, mo1.isFederated(FederationMap.FType.COL)));
+ }
}
else if(instOpcode.equalsIgnoreCase("rev")) {
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()}, true);
+ new long[] {mo1.getFedMapping().getID()}, fedoutFlagInString);
mo1.getFedMapping().execute(getTID(), true, fr1);
if(mo1.isFederated(FederationMap.FType.ROW))
@@ -123,6 +136,11 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+
+ if ( _fedOut != null && _fedOut.isForcedLocal() ){
+ out.acquireReadAndRelease();
+ out.getFedMapping().cleanup(getTID(), fr1.getID());
+ }
}
else if (instOpcode.equals("rdiag")) {
RdiagResult result;
@@ -158,6 +176,10 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
.set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1),
(int) mo1.getBlocksize());
rdiag.setFedMapping(diagFedMap);
+ if ( _fedOut != null && _fedOut.isForcedLocal() ){
+ rdiag.acquireReadAndRelease();
+ rdiag.getFedMapping().cleanup(getTID(), rdiag.getFedMapping().getID());
+ }
}
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3d96cd9..3ac462c 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
+import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
@@ -2097,6 +2098,31 @@ public abstract class AutomatedTestBase {
return false;
}
+ /**
+ * Checks if given strings are all in the set of heavy hitters.
+ * @param str opcodes for which it is checked if all are in the heavy hitters
+ * @return true if all given strings are in the set of heavy hitters
+ */
+ protected boolean heavyHittersContainsAllString(String... str){
+ Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+ return Arrays.stream(str).allMatch(heavyHitters::contains);
+ }
+
+ /**
+ * Returns an array of the given opcodes which are not present in the set of heavy hitter opcodes.
+ * @param opcodes for which it is checked if they are among the heavy hitters
+ * @return array of opcodes not found in heavy hitters
+ */
+ protected String[] missingHeavyHitters(String... opcodes){
+ Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+ List<String> missingHeavyHitters = new ArrayList<>();
+ for (String opcode : opcodes){
+ if ( !heavyHitters.contains(opcode) )
+ missingHeavyHitters.add(opcode);
+ }
+ return missingHeavyHitters.toArray(new String[0]);
+ }
+
protected boolean heavyHittersContainsString(String str, int minCount) {
int count = 0;
for(String opcode : Statistics.getCPHeavyHitterOpCodes())
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000..0092a3a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.cost.FederatedCost;
+import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.sysds.common.Types.OpOp2.MULT;
+
+public class FederatedCostEstimatorTest extends AutomatedTestBase {
+
+ private static final String TEST_DIR = "functions/privacy/fedplanning/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/";
+ FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+
+ @Override
+ public void setUp() {}
+
+ @Test
+ public void simpleBinary() {
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+
+ /*
+ * HOP Occurrences ComputeCost ReadCost ComputeCostFinal ReadCostFinal
+ * ------------------------------------------------------------------------------------------
+ * LiteralOp 16 1 0 0.0625 0
+ * DataGenOp 2 100 64 6.25 6.4
+ * BinaryOp 1 100 1600 6.25 160
+ * TOSTRING 1 1 800 0.0625 80
+ * UnaryOp 1 1 8 0.0625 0.8
+ */
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+
+ double expectedCost = computeCost + readCost;
+ runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void ifElseTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+ runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void whileTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = (computeCost + readCost + 0.0625) * fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+ runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void forLoopTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+ double expectedCost = (computeCost + readCost + predicateCost) * 5;
+ runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void parForLoopTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625;
+ double expectedCost = (computeCost + readCost + predicateCost) * 5;
+ runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void functionTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) / (fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) / (fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = (computeCost + readCost);
+ runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void federatedMultiply() {
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+
+ double literalOpCost = 10*0.0625;
+ double naryOpCostSpecial = (0.125+2.2);
+ double naryOpCostSpecial2 = (0.25+6.4);
+ double naryOpCost = 4*(0.125+1.6);
+ double reorgOpCost = 6250+80015.2+160030.4;
+ double binaryOpMultCost = 3125+160000;
+ double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+ double dataOpCost = 2*(6250+5.6);
+ double dataOpWriteCost = 6.25+100.3;
+
+ double expectedCost = literalOpCost + naryOpCost + naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+ + binaryOpMultCost + aggBinaryOpCost + dataOpCost + dataOpWriteCost;
+ runTest("FederatedMultiplyCostEstimatorTest.dml", false, expectedCost);
+
+ double aggBinaryActualCost = hops.stream()
+ .filter(hop -> hop instanceof AggBinaryOp)
+ .mapToDouble(aggHop -> aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost())
+ .sum();
+ Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost, 0.0001);
+
+ double writeActualCost = hops.stream()
+ .filter(hop -> hop instanceof DataOp)
+ .mapToDouble(writeHop -> writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost())
+ .sum();
+ Assert.assertEquals(dataOpWriteCost+dataOpCost, writeActualCost, 0.0001);
+ }
+
+ Set<Hop> hops = new HashSet<>();
+
+ /**
+ * Recursively adds the hop and its inputs to the set of hops.
+ * @param hop root to be added to set of hops
+ */
+ private void addHop(Hop hop){
+ hops.add(hop);
+ for(Hop inHop : hop.getInput()){
+ addHop(inHop);
+ }
+ }
+
+ /**
+ * Sets dimensions of federated X and Y, forces DataOps and binary multiplications to FOUT (exec type FED), and forces all other hops to LOUT.
+ * @param prog dml program where the HOPS are modified
+ */
+ private void modifyFedouts(DMLProgram prog){
+ prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+ hops.forEach(hop -> {
+ if ( hop instanceof DataOp || (hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == MULT ) ){
+ hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+ hop.setExecType(Types.ExecType.FED);
+ } else {
+ hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+ }
+ if ( hop.getOpString().equals("Fed Y") || hop.getOpString().equals("Fed X") ){
+ hop.setDim1(10000);
+ hop.setDim2(10);
+ }
+ });
+ }
+
+ @SuppressWarnings("unused")
+ private void printHopsInfo(){
+ //LiteralOp
+ long literalCount = hops.stream().filter(hop -> hop instanceof LiteralOp).count();
+ System.out.println("LiteralOp Count: " + literalCount);
+ //NaryOp
+ long naryCount = hops.stream().filter(hop -> hop instanceof NaryOp).count();
+ System.out.println("NaryOp Count " + naryCount);
+ //ReorgOp
+ long reorgCount = hops.stream().filter(hop -> hop instanceof ReorgOp).count();
+ System.out.println("ReorgOp Count: " + reorgCount);
+ //BinaryOp
+ long binaryCount = hops.stream().filter(hop -> hop instanceof BinaryOp).count();
+ System.out.println("Binary count: " + binaryCount);
+ //AggBinaryOp
+ long aggBinaryCount = hops.stream().filter(hop -> hop instanceof AggBinaryOp).count();
+ System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+ //DataOp
+ long dataOpCount = hops.stream().filter(hop -> hop instanceof DataOp).count();
+ System.out.println("DataOp Count: " + dataOpCount);
+
+ hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+ }
+
+ private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) {
+ boolean raisedException = false;
+ try
+ {
+ setTestConfig(scriptFilename);
+ String dmlScriptString = readScript(scriptFilename);
+
+ //parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java)
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+ modifyFedouts(prog);
+ dmlt.rewriteHopsDAG(prog);
+ hops = new HashSet<>();
+ prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop));
+ }
+
+ FederatedCost actualCost = fedCostEstimator.costEstimate(prog);
+ Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001);
+ }
+ catch(LanguageException ex) {
+ raisedException = true;
+ if(raisedException!=expectedException)
+ ex.printStackTrace();
+ }
+ catch(Exception ex2) {
+ ex2.printStackTrace();
+ throw new RuntimeException(ex2);
+ }
+
+ Assert.assertEquals("Expected exception does not match raised exception",
+ expectedException, raisedException);
+ }
+
+ private void setTestConfig(String scriptFilename) throws FileNotFoundException {
+ int index = scriptFilename.lastIndexOf(".dml");
+ String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length());
+ TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+ addTestConfiguration(testName, testConfig);
+ loadTestConfiguration(testConfig);
+
+ DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+ ConfigurationManager.setLocalConfig(conf);
+ }
+
+ private static String readScript(String scriptFilename) throws IOException {
+ return DMLScript.readDMLScript(true, HOME + scriptFilename);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 342a26b..e0ef884 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -35,7 +35,7 @@ import org.apache.sysds.test.TestUtils;
import java.util.Arrays;
import java.util.Collection;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
@@ -76,34 +76,40 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
@Test
public void federatedMultiplyCP() {
- federatedTwoMatricesSingleNodeTest(TEST_NAME);
+ String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME, expectedHeavyHitters);
}
@Test
public void federatedRowSum(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+ String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_2, expectedHeavyHitters);
}
@Test
public void federatedTernarySequence(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
+ String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_3, expectedHeavyHitters);
}
@Test
public void federatedAggregateBinarySequence(){
cols = rows;
- federatedTwoMatricesSingleNodeTest(TEST_NAME_4);
+ String[] expectedHeavyHitters = new String[]{"fed_ba+*", "fed_*", "fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_4, expectedHeavyHitters);
}
@Test
public void federatedAggregateBinaryColFedSequence(){
cols = rows;
- federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
+ String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_*","fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_5, expectedHeavyHitters);
}
@Test
public void federatedAggregateBinarySequence2(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
+ String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters);
}
private void writeStandardMatrix(String matrixName, long seed){
@@ -166,11 +172,11 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
}
}
- private void federatedTwoMatricesSingleNodeTest(String testName){
- federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+ private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){
+ federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, expectedHeavyHitters);
}
- private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName) {
+ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, String[] expectedHeavyHitters) {
OptimizerUtils.FEDERATED_COMPILATION = true;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -213,10 +219,9 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- if ( testName.equals(TEST_NAME_3) )
- assertTrue(heavyHittersContainsString("fed_+*", "fed_1-*"));
- else
- assertTrue(heavyHittersContainsString("fed_*", "fed_ba+*"));
+ if (!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are missing: "
+ + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
new file mode 100644
index 0000000..1899614
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+a = matrix (1, rows=10, cols=10);
+b = matrix (2, rows=10, cols=10);
+c = a * b;
+print(toString(c));
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
new file mode 100644
index 0000000..dfeacec
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+ph = "placeholder"
+rows = 10000
+cols = 10
+X = federated(addresses=list(ph, ph),
+ ranges=list(list(0, 0), list(rows / 2, cols), list(rows / 2, 0), list(rows, cols)))
+Y = federated(addresses=list(ph, ph),
+ ranges=list(list(0, 0), list(rows/2, cols), list(rows / 2, 0), list(rows, cols)))
+Z0 = X * Y
+Z = t(Z0) %*% X
+write(Z, ph)
diff --git a/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..b80e745
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+for ( i in 1:5 ){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
new file mode 100644
index 0000000..1f0d876
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+multiplication = function (){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
+multiplication();
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
new file mode 100644
index 0000000..b0194e3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ( 1 ){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
+else {
+ print("No result");
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..32a49ec
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+parfor ( i in 1:5 ){ # parallel for over 1:5; the loop variable i is not used in the body
+ a = matrix (1, rows=10, cols=10); # 10x10 matrix with every cell set to 1
+ b = matrix (2, rows=10, cols=10); # 10x10 matrix with every cell set to 2
+ c = a * b; # '*' is the cell-wise product in DML (matrix multiplication would be %*%)
+ print(toString(c)); # render the result matrix as a string and print it
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
new file mode 100644
index 0000000..faea01d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+while ( 1 ){ # NOTE(review): constant condition never changed by the body - presumably compiled for costing only, not executed; confirm
+ a = matrix (1, rows=10, cols=10); # 10x10 matrix with every cell set to 1
+ b = matrix (2, rows=10, cols=10); # 10x10 matrix with every cell set to 2
+ c = a * b; # '*' is the cell-wise product in DML (matrix multiplication would be %*%)
+ print(toString(c)); # render the result matrix as a string and print it
+}
\ No newline at end of file