You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/05/13 13:16:52 UTC

[systemds] branch main updated: [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new aba1707852 [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
aba1707852 is described below

commit aba1707852d546a9c46ada3f185824df063494bd
Author: sebwrede <sw...@know-center.at>
AuthorDate: Wed May 11 16:08:53 2022 +0200

    [SYSTEMDS-3018] Federated Planner Forced ExecType And FedOut Info
    
    Applying this commit will:
    1) Let the federated cost-based planner set a forced ExecType on hops, with related ExecType adjustments
    2) Show the federated output (FedOut) alternatives of each hop in the HOPS explain output
    
    Closes #1612.
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |  3 ++
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  4 +--
 src/main/java/org/apache/sysds/hops/Hop.java       | 18 ------------
 .../java/org/apache/sysds/hops/cost/HopRel.java    | 32 ++++++++++++++++------
 .../hops/fedplanner/FederatedPlannerCostbased.java | 14 +++++++++-
 .../apache/sysds/hops/fedplanner/MemoTable.java    | 15 ++++++++++
 .../runtime/instructions/FEDInstructionParser.java |  3 ++
 src/main/java/org/apache/sysds/utils/Explain.java  | 22 +++++++++++++++
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  3 +-
 9 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index c461b69bac..23439b182e 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -608,6 +608,9 @@ public class AggUnaryOp extends MultiThreadedHop
 		ExecType et_input = input1.optFindExecType();
 		// Because ternary aggregate are not supported on GPU
 		et_input = et_input == ExecType.GPU ? ExecType.CP :  et_input;
+		// If forced ExecType is FED, it means that the federated planner updated the ExecType and
+		// execution may fail if ExecType is not FED
+		et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED : et_input;
 		
 		return new TernaryAggregate(in1, in2, in3, AggOp.SUM, 
 			OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 791c3bdfbd..2346eeebfe 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -755,11 +755,9 @@ public class BinaryOp extends MultiThreadedHop {
 			checkAndSetInvalidCPDimsAndSize();
 		}
 
-		updateETFed();
-			
 		//spark-specific decision refinement (execute unary scalar w/ spark input and 
 		//single parent also in spark because it's likely cheap and reduces intermediates)
-		if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP
+		if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED
 			&& getDataType().isMatrix() && (dt1.isScalar() || dt2.isScalar()) 
 			&& supportsMatrixScalarOperations()                          //scalar operations
 			&& !(getInput().get(dt1.isScalar()?1:0) instanceof DataOp)   //input is not checkpoint
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 4ce9a4b90f..e1e4fcc8d4 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -909,24 +909,6 @@ public abstract class Hop implements ParseInfo {
 		return et;
 	}
 
-	/**
-	 * Update the execution type if input is federated.
-	 * This method only has an effect if FEDERATED_COMPILATION is activated.
-	 * Federated compilation is activated in OptimizerUtils.
-	 */
-	public void updateETFed() {
-		boolean localOut = hasLocalOutput();
-		boolean fedIn = getInput().stream().anyMatch(
-			in -> in.hasFederatedOutput() && !(in.prefetchActivated() && localOut));
-		if( isFederatedDataOp() || fedIn ){
-			setForcedExecType(ExecType.FED);
-			//TODO: Temporary solution where _etype is set directly
-			// since forcedExecType for BinaryOp may be overwritten
-			// if updateETFed is not called from optFindExecType.
-			_etype = ExecType.FED;
-		}
-	}
-
 	/**
 	 * Checks if ExecType is federated.
 	 * @return true if ExecType is federated
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index 89a0f7cb50..70785950ca 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.hops.cost;
 
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.fedplanner.FTypes;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
@@ -43,8 +44,9 @@ import java.util.stream.Collectors;
 public class HopRel {
 	protected final Hop hopRef;
 	protected final FEDInstruction.FederatedOutput fedOut;
+	protected ExecType execType;
 	protected FTypes.FType fType;
-	protected final FederatedCost cost;
+	protected FederatedCost cost;
 	protected final Set<Long> costPointerSet = new HashSet<>();
 	protected List<Hop> inputHops;
 	protected List<HopRel> inputDependency = new ArrayList<>();
@@ -70,6 +72,13 @@ public class HopRel {
 		this(associatedHop, fedOut, null, hopRelMemo, inputs);
 	}
 
+	private HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, List<Hop> inputs){
+		hopRef = associatedHop;
+		this.fedOut = fedOut;
+		this.fType = fType;
+		this.inputHops = inputs;
+	}
+
 	/**
 	 * Constructs a HopRel with input dependency and cost estimate based on entries in hopRelMemo.
 	 * @param associatedHop hop associated with this HopRel
@@ -79,21 +88,17 @@ public class HopRel {
 	 * @param inputs hop inputs which input dependencies and cost is based on
 	 */
 	public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, MemoTable hopRelMemo, ArrayList<Hop> inputs){
-		hopRef = associatedHop;
-		this.fedOut = fedOut;
-		this.fType = fType;
-		this.inputHops = inputs;
+		this(associatedHop, fedOut, fType, inputs);
 		setInputDependency(hopRelMemo);
 		cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+		setExecType();
 	}
 
 	public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FType fType, MemoTable hopRelMemo, List<Hop> inputs, List<FType> inputDependency){
-		hopRef = associatedHop;
-		this.fedOut = fedOut;
-		this.inputHops = inputs;
-		this.fType = fType;
+		this(associatedHop, fedOut, fType, inputs);
 		setInputFTypeDependency(inputs, inputDependency, hopRelMemo);
 		cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
+		setExecType();
 	}
 
 	private void setInputFTypeDependency(List<Hop> inputs, List<FType> inputDependency, MemoTable hopRelMemo){
@@ -103,6 +108,11 @@ public class HopRel {
 		validateInputDependency();
 	}
 
+	private void setExecType(){
+		if ( inputDependency.stream().anyMatch(HopRel::hasFederatedOutput) )
+			execType = ExecType.FED;
+	}
+
 	/**
 	 * Adds hopID to set of hops pointing to this HopRel.
 	 * By storing the hopID it can later be determined if the cost
@@ -154,6 +164,10 @@ public class HopRel {
 		this.fType = fType;
 	}
 
+	public ExecType getExecType(){
+		return execType;
+	}
+
 	/**
 	 * Returns FOUT HopRel for given hop found in hopRelMemo or returns null if HopRel not found.
 	 * @param hop to look for in hopRelMemo
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index a809d2bafd..e9a25206f8 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -30,6 +30,7 @@ import java.util.Set;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.hops.DataOp;
@@ -53,6 +54,8 @@ import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.Explain.ExplainType;
 
 public class FederatedPlannerCostbased extends AFederatedPlanner {
 	private static final Log LOG = LogFactory.getLog(FederatedPlannerCostbased.class.getName());
@@ -77,6 +80,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 		prog.updateRepetitionEstimates();
 		rewriteStatementBlocks(prog, prog.getStatementBlocks());
 		setFinalFedouts();
+		updateExplain();
 	}
 	
 	/**
@@ -215,7 +219,6 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 			updateFederatedOutput(root, rootHopRel);
 			visitInputDependency(rootHopRel);
 		}
-		root.updateETFed();
 	}
 
 	/**
@@ -238,6 +241,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 	private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
 		root.setFederatedOutput(updateHopRel.getFederatedOutput());
 		root.setFederatedCost(updateHopRel.getCostObject());
+		root.setForcedExecType(updateHopRel.getExecType());
 		forceFixedFedOut(root);
 		LOG.trace("Updated fedOut to " + updateHopRel.getFederatedOutput() + " for hop "
 			+ root.getHopID() + " opcode: " + root.getOpString());
@@ -394,6 +398,14 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 		}
 	}
 
+	/**
+	 * Add hopRelMemo to Explain class to get explain info related to federated enumeration.
+	 */
+	private void updateExplain(){
+		if (DMLScript.EXPLAIN == ExplainType.HOPS)
+			Explain.setMemo(hopRelMemo);
+	}
+
 	/**
 	 * Write HOP visit to debug log if debug is activated.
 	 * @param currentHop hop written to log
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index 6b9da0f400..5b399bd499 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -23,6 +23,7 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.cost.HopRel;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 
 import java.util.Comparator;
 import java.util.HashMap;
@@ -46,6 +47,20 @@ public class MemoTable {
 	 */
 	private final static Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
 
+	/**
+	 * Get list of strings representing the different
+	 * hopRel federated outputs related to root hop.
+	 * @param root for which HopRel fedouts are found
+	 * @return federated output values as strings
+	 */
+	public List<String> getFedOutAlternatives(Hop root){
+		if ( !containsHop(root) )
+			return null;
+		else return hopRelMemo.get(root.getHopID()).stream()
+			.map(HopRel::getFederatedOutput)
+			.map(FEDInstruction.FederatedOutput::name).collect(Collectors.toList());
+	}
+
 	/**
 	 * Get the HopRel with minimum cost for given root hop
 	 * @param root hop for which minimum cost HopRel is found
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 2fde0a0fbc..58ab43daba 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions;
 
 import org.apache.sysds.lops.Append;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
@@ -52,6 +53,8 @@ public class FEDInstructionParser extends InstructionParser
 		String2FEDInstructionType.put( "uak+"    , FEDType.AggregateUnary );
 		String2FEDInstructionType.put( "uark+"   , FEDType.AggregateUnary );
 		String2FEDInstructionType.put( "uack+"   , FEDType.AggregateUnary );
+		String2FEDInstructionType.put( "uamax"   , FEDType.AggregateUnary );
+		String2FEDInstructionType.put( "uamin"   , FEDType.AggregateUnary );
 		String2FEDInstructionType.put( "uasqk+"  , FEDType.AggregateUnary );
 		String2FEDInstructionType.put( "uarsqk+" , FEDType.AggregateUnary );
 		String2FEDInstructionType.put( "uacsqk+" , FEDType.AggregateUnary );
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index ba6fb7150e..c8e5902511 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -35,6 +35,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysds.hops.fedplanner.MemoTable;
 import org.apache.sysds.hops.ipa.FunctionCallGraph;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.parser.DMLProgram;
@@ -78,6 +79,9 @@ public class Explain
 	private static final boolean SHOW_DATA_DEPENDENCIES     = true;
 	private static final boolean SHOW_DATA_FLOW_PROPERTIES  = true;
 
+	//federated execution plan alternatives
+	private static MemoTable MEMO_TABLE;
+
 	//different explain levels
 	public enum ExplainType {
 		NONE, 	  // explain disabled
@@ -101,6 +105,14 @@ public class Explain
 		public int numChkpts = 0;
 	}
 
+	/**
+	 * Store memo table for adding additional explain info regarding hops.
+	 * @param memoTable to store in Explain
+	 */
+	public static void setMemo(MemoTable memoTable){
+		MEMO_TABLE = memoTable;
+	}
+
 	//////////////
 	// public explain interface
 
@@ -600,6 +612,16 @@ public class Explain
 		if (hop.getExecType() != null)
 			sb.append(", " + hop.getExecType());
 
+		if ( MEMO_TABLE != null && MEMO_TABLE.containsHop(hop) ){
+			List<String> fedAlts = MEMO_TABLE.getFedOutAlternatives(hop);
+			if ( fedAlts != null ){
+				sb.append(" [ ");
+				for ( String fedAlt : fedAlts )
+					sb.append(fedAlt).append(" ");
+				sb.append("]");
+			}
+		}
+
 		sb.append('\n');
 
 		hop.setVisited();
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index e9ab6b6ad0..1ba9966773 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -138,7 +138,8 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
 
 			// Run actual dml script with federated matrix
 			fullDMLScriptName = HOME + TEST_NAME + ".dml";
-			programArgs = new String[] { "-stats", "-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			programArgs = new String[] { "-stats", "-explain", "hops", "-nvargs",
+				"X1=" + TestUtils.federatedAddress(port1, input("X1")),
 				"X2=" + TestUtils.federatedAddress(port2, input("X2")),
 				"Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
 			runTest(true, false, null, -1);