You are viewing a plain-text version of this content. The canonical link to the original (the "here" hyperlink on the source page) was not preserved in this copy.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/05/13 13:16:52 UTC
[systemds] branch main updated: [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new aba1707852 [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
aba1707852 is described below
commit aba1707852d546a9c46ada3f185824df063494bd
Author: sebwrede <sw...@know-center.at>
AuthorDate: Wed May 11 16:08:53 2022 +0200
[SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
Applying this commit will:
1) Add a forced ExecType and other adjustments to the ExecType
2) Add FedOut information to the explain-hops output
Closes #1612.
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 3 ++
src/main/java/org/apache/sysds/hops/BinaryOp.java | 4 +--
src/main/java/org/apache/sysds/hops/Hop.java | 18 ------------
.../java/org/apache/sysds/hops/cost/HopRel.java | 32 ++++++++++++++++------
.../hops/fedplanner/FederatedPlannerCostbased.java | 14 +++++++++-
.../apache/sysds/hops/fedplanner/MemoTable.java | 15 ++++++++++
.../runtime/instructions/FEDInstructionParser.java | 3 ++
src/main/java/org/apache/sysds/utils/Explain.java | 22 +++++++++++++++
.../fedplanning/FederatedL2SVMPlanningTest.java | 3 +-
9 files changed, 82 insertions(+), 32 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index c461b69bac..23439b182e 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -608,6 +608,9 @@ public class AggUnaryOp extends MultiThreadedHop
ExecType et_input = input1.optFindExecType();
// Because ternary aggregate are not supported on GPU
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
+ // If forced ExecType is FED, it means that the federated planner updated the ExecType and
+ // execution may fail if ExecType is not FED
+ et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED : et_input;
return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 791c3bdfbd..2346eeebfe 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -755,11 +755,9 @@ public class BinaryOp extends MultiThreadedHop {
checkAndSetInvalidCPDimsAndSize();
}
- updateETFed();
-
//spark-specific decision refinement (execute unary scalar w/ spark input and
//single parent also in spark because it's likely cheap and reduces intermediates)
- if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP
+ if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED
&& getDataType().isMatrix() && (dt1.isScalar() || dt2.isScalar())
&& supportsMatrixScalarOperations() //scalar operations
&& !(getInput().get(dt1.isScalar()?1:0) instanceof DataOp) //input is not checkpoint
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 4ce9a4b90f..e1e4fcc8d4 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -909,24 +909,6 @@ public abstract class Hop implements ParseInfo {
return et;
}
- /**
- * Update the execution type if input is federated.
- * This method only has an effect if FEDERATED_COMPILATION is activated.
- * Federated compilation is activated in OptimizerUtils.
- */
- public void updateETFed() {
- boolean localOut = hasLocalOutput();
- boolean fedIn = getInput().stream().anyMatch(
- in -> in.hasFederatedOutput() && !(in.prefetchActivated() && localOut));
- if( isFederatedDataOp() || fedIn ){
- setForcedExecType(ExecType.FED);
- //TODO: Temporary solution where _etype is set directly
- // since forcedExecType for BinaryOp may be overwritten
- // if updateETFed is not called from optFindExecType.
- _etype = ExecType.FED;
- }
- }
-
/**
* Checks if ExecType is federated.
* @return true if ExecType is federated
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index 89a0f7cb50..70785950ca 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.cost;
import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
@@ -43,8 +44,9 @@ import java.util.stream.Collectors;
public class HopRel {
protected final Hop hopRef;
protected final FEDInstruction.FederatedOutput fedOut;
+ protected ExecType execType;
protected FTypes.FType fType;
- protected final FederatedCost cost;
+ protected FederatedCost cost;
protected final Set<Long> costPointerSet = new HashSet<>();
protected List<Hop> inputHops;
protected List<HopRel> inputDependency = new ArrayList<>();
@@ -70,6 +72,13 @@ public class HopRel {
this(associatedHop, fedOut, null, hopRelMemo, inputs);
}
+ private HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, List<Hop> inputs){
+ hopRef = associatedHop;
+ this.fedOut = fedOut;
+ this.fType = fType;
+ this.inputHops = inputs;
+ }
+
/**
* Constructs a HopRel with input dependency and cost estimate based on entries in hopRelMemo.
* @param associatedHop hop associated with this HopRel
@@ -79,21 +88,17 @@ public class HopRel {
* @param inputs hop inputs which input dependencies and cost is based on
*/
public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, MemoTable hopRelMemo, ArrayList<Hop> inputs){
- hopRef = associatedHop;
- this.fedOut = fedOut;
- this.fType = fType;
- this.inputHops = inputs;
+ this(associatedHop, fedOut, fType, inputs);
setInputDependency(hopRelMemo);
cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+ setExecType();
}
public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, MemoTable hopRelMemo, List<Hop> inputs, List<FType> inputDependency){
- hopRef = associatedHop;
- this.fedOut = fedOut;
- this.inputHops = inputs;
- this.fType = fType;
+ this(associatedHop, fedOut, fType, inputs);
setInputFTypeDependency(inputs, inputDependency, hopRelMemo);
cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+ setExecType();
}
private void setInputFTypeDependency(List<Hop> inputs, List<FType> inputDependency, MemoTable hopRelMemo){
@@ -103,6 +108,11 @@ public class HopRel {
validateInputDependency();
}
+ private void setExecType(){
+ if ( inputDependency.stream().anyMatch(HopRel::hasFederatedOutput) )
+ execType = ExecType.FED;
+ }
+
/**
* Adds hopID to set of hops pointing to this HopRel.
* By storing the hopID it can later be determined if the cost
@@ -154,6 +164,10 @@ public class HopRel {
this.fType = fType;
}
+ public ExecType getExecType(){
+ return execType;
+ }
+
/**
* Returns FOUT HopRel for given hop found in hopRelMemo or returns null if HopRel not found.
* @param hop to look for in hopRelMemo
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index a809d2bafd..e9a25206f8 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -30,6 +30,7 @@ import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.hops.DataOp;
@@ -53,6 +54,8 @@ import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.Explain.ExplainType;
public class FederatedPlannerCostbased extends AFederatedPlanner {
private static final Log LOG = LogFactory.getLog(FederatedPlannerCostbased.class.getName());
@@ -77,6 +80,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
prog.updateRepetitionEstimates();
rewriteStatementBlocks(prog, prog.getStatementBlocks());
setFinalFedouts();
+ updateExplain();
}
/**
@@ -215,7 +219,6 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
updateFederatedOutput(root, rootHopRel);
visitInputDependency(rootHopRel);
}
- root.updateETFed();
}
/**
@@ -238,6 +241,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
root.setFederatedOutput(updateHopRel.getFederatedOutput());
root.setFederatedCost(updateHopRel.getCostObject());
+ root.setForcedExecType(updateHopRel.getExecType());
forceFixedFedOut(root);
LOG.trace("Updated fedOut to " + updateHopRel.getFederatedOutput() + " for hop "
+ root.getHopID() + " opcode: " + root.getOpString());
@@ -394,6 +398,14 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
}
}
+ /**
+ * Add hopRelMemo to Explain class to get explain info related to federated enumeration.
+ */
+ private void updateExplain(){
+ if (DMLScript.EXPLAIN == ExplainType.HOPS)
+ Explain.setMemo(hopRelMemo);
+ }
+
/**
* Write HOP visit to debug log if debug is activated.
* @param currentHop hop written to log
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index 6b9da0f400..5b399bd499 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -23,6 +23,7 @@ import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import java.util.Comparator;
import java.util.HashMap;
@@ -46,6 +47,20 @@ public class MemoTable {
*/
private final static Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
+ /**
+ * Get list of strings representing the different
+ * hopRel federated outputs related to root hop.
+ * @param root for which HopRel fedouts are found
+ * @return federated output values as strings
+ */
+ public List<String> getFedOutAlternatives(Hop root){
+ if ( !containsHop(root) )
+ return null;
+ else return hopRelMemo.get(root.getHopID()).stream()
+ .map(HopRel::getFederatedOutput)
+ .map(FEDInstruction.FederatedOutput::name).collect(Collectors.toList());
+ }
+
/**
* Get the HopRel with minimum cost for given root hop
* @param root hop for which minimum cost HopRel is found
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 2fde0a0fbc..58ab43daba 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions;
import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
@@ -52,6 +53,8 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "uak+" , FEDType.AggregateUnary );
String2FEDInstructionType.put( "uark+" , FEDType.AggregateUnary );
String2FEDInstructionType.put( "uack+" , FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uamax" , FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uamin" , FEDType.AggregateUnary );
String2FEDInstructionType.put( "uasqk+" , FEDType.AggregateUnary );
String2FEDInstructionType.put( "uarsqk+" , FEDType.AggregateUnary );
String2FEDInstructionType.put( "uacsqk+" , FEDType.AggregateUnary );
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index ba6fb7150e..c8e5902511 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -35,6 +35,7 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
@@ -78,6 +79,9 @@ public class Explain
private static final boolean SHOW_DATA_DEPENDENCIES = true;
private static final boolean SHOW_DATA_FLOW_PROPERTIES = true;
+ //federated execution plan alternatives
+ private static MemoTable MEMO_TABLE;
+
//different explain levels
public enum ExplainType {
NONE, // explain disabled
@@ -101,6 +105,14 @@ public class Explain
public int numChkpts = 0;
}
+ /**
+ * Store memo table for adding additional explain info regarding hops.
+ * @param memoTable to store in Explain
+ */
+ public static void setMemo(MemoTable memoTable){
+ MEMO_TABLE = memoTable;
+ }
+
//////////////
// public explain interface
@@ -600,6 +612,16 @@ public class Explain
if (hop.getExecType() != null)
sb.append(", " + hop.getExecType());
+ if ( MEMO_TABLE != null && MEMO_TABLE.containsHop(hop) ){
+ List<String> fedAlts = MEMO_TABLE.getFedOutAlternatives(hop);
+ if ( fedAlts != null ){
+ sb.append(" [ ");
+ for ( String fedAlt : fedAlts )
+ sb.append(fedAlt).append(" ");
+ sb.append("]");
+ }
+ }
+
sb.append('\n');
hop.setVisited();
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index e9ab6b6ad0..1ba9966773 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -138,7 +138,8 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] { "-stats", "-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] { "-stats", "-explain", "hops", "-nvargs",
+ "X1=" + TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
runTest(true, false, null, -1);