You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/10/19 15:07:37 UTC

[systemds] branch master updated: [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2026cff  [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
2026cff is described below

commit 2026cfff97fd992c75b6b56ac01d8d199f4f9db3
Author: sebwrede <sw...@know-center.at>
AuthorDate: Mon Oct 11 18:30:28 2021 +0200

    [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
    
    This commit ensures that the Reorg operations rdiag and rev are compiled with the federated output flag (FOUT/LOUT).
    Additionally, it removes rshape from the FEDInstructionParser and lists rshape and rsort there only as commented-out
    entries, since federated parsing of these Reorg types is not supported yet.
    Closes #1414.
---
 .../hops/rewrite/IPAPassRewriteFederatedPlan.java  |  1 -
 src/main/java/org/apache/sysds/lops/Lop.java       |  6 +-
 .../runtime/instructions/FEDInstructionParser.java |  3 +-
 .../runtime/instructions/InstructionUtils.java     | 25 ++++++-
 .../instructions/fed/ReorgFEDInstruction.java      | 81 +++++++++++++---------
 .../instructions/fed/UnaryFEDInstruction.java      | 14 ++++
 .../federated/primitives/FederatedRdiagTest.java   | 17 +++++
 .../federated/primitives/FederatedRevTest.java     | 18 +++++
 8 files changed, 125 insertions(+), 40 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
index cbc21cf..377ebb1 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -252,7 +252,6 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
 		if ( hopRels.isEmpty() )
 			hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
 		hopRelMemo.put(currentHop.getHopID(), hopRels);
-		currentHop.setVisited();
 	}
 
 	/**
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index e014d3c..7da091f 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -117,9 +117,9 @@ public abstract class Lop
 	protected PrivacyConstraint privacyConstraint;
 
 	/**
-	 * Boolean defining if the output of the operation should be federated.
-	 * If it is true, the output should be kept at federated sites.
-	 * If it is false, the output should be retrieved by the coordinator.
+	 * Enum defining if the output of the operation should be forced federated, forced local or neither.
+	 * If it is FOUT, the output should be kept at federated sites.
+	 * If it is LOUT, the output should be retrieved by the coordinator.
 	 */
 	protected FederatedOutput _fedOutput = null;
 	
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 755287a..8000da7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -63,8 +63,9 @@ public class FEDInstructionParser extends InstructionParser
 		// Reorg Instruction Opcodes (repositioning of existing values)
 		String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
 		String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
-		String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
 		String2FEDInstructionType.put( "rev"    , FEDType.Reorg );
+		//String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
+		//String2FEDInstructionType.put( "rsort"  , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
 
 		// Ternary Instruction Opcodes
 		String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 246ed87..9991edf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -85,6 +85,7 @@ import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
 import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -1144,8 +1145,30 @@ public class InstructionUtils
 		return linst;
 	}
 
+	/**
+	 * Removes federated output flag from the end of the instruction string if the flag is present.
+	 * @param linst instruction string
+	 * @return instruction string with no federated output flag
+	 */
 	public static String removeFEDOutputFlag(String linst){
-		return linst.substring(0, linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
+		int lastOperandStartIndex = linst.lastIndexOf(Lop.OPERAND_DELIMITOR);
+		String lastOperand = linst.substring(lastOperandStartIndex);
+		if ( containsFEDOutputFlag(lastOperand) )
+			return linst.substring(0, lastOperandStartIndex);
+		else return linst;
+	}
+
+	/**
+	 * Checks whether the given operand string contains a federated output flag
+	 * @param operandString which is checked for federated output flag
+	 * @return true if the given operand string contains a federated output flag
+	 */
+	private static boolean containsFEDOutputFlag(String operandString){
+		for (FederatedOutput fedOutput : FederatedOutput.values()){
+			if ( operandString.contains(fedOutput.name()) )
+				return true;
+		}
+		return false;
 	}
 
 	private static String replaceOperand(String linst, CPOperand oldOperand, String newOperandName){
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index c32b15b..4202498 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -54,8 +54,6 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
-	@SuppressWarnings("unused")
-	private static boolean fedoutFlagInString = false;
 
 	public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
 		super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -71,23 +69,25 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		String opcode = parts[0];
+		FederatedOutput fedOut;
 		if ( opcode.equalsIgnoreCase("r'") ) {
 			InstructionUtils.checkNumFields(str, 2, 3, 4);
 			in.split(parts[1]);
 			out.split(parts[2]);
 			int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0 : Integer.parseInt(parts[3]);
-			FederatedOutput fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?  FederatedOutput.valueOf(parts[3]) :
-				FederatedOutput.valueOf(parts[4]);
+			fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
+				FederatedOutput.valueOf(parts[3]) : FederatedOutput.valueOf(parts[4]);
 			return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, fedOut);
 		}
 		else if ( opcode.equalsIgnoreCase("rdiag") ) {
 			parseUnaryInstruction(str, in, out); //max 2 operands
-			return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
+			fedOut = parseFedOutFlag(str, 3);
+			return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str, fedOut);
 		}
 		else if ( opcode.equalsIgnoreCase("rev") ) {
-			fedoutFlagInString = parts.length > 3;
 			parseUnaryInstruction(str, in, out); //max 2 operands
-			return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+			fedOut = parseFedOutFlag(str, 3);
+			return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut);
 		}
 		else {
 			throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode);
@@ -117,7 +117,6 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 			mo1.getFedMapping().execute(getTID(), true, fr, fr1);
 
 			if (_fedOut != null && !_fedOut.isForcedLocal()){
-				mo1.getFedMapping().execute(getTID(), true, fr1);
 				//drive output federated mapping
 				MatrixObject out = ec.getMatrixObject(output);
 				out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
@@ -146,10 +145,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 			out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
 			out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
 
-			if ( _fedOut != null && _fedOut.isForcedLocal() ){
-				out.acquireReadAndRelease();
-				out.getFedMapping().cleanup(getTID(), fr1.getID());
-			}
+			optionalForceLocal(out);
 		}
 		else if (instOpcode.equals("rdiag")) {
 			RdiagResult result;
@@ -160,24 +156,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 				result = rdiagM2V(mo1, r_op);
 			}
 
-			FederationMap diagFedMap = result.getFedMap();
-			Map<FederatedRange, int[]> dcs = result.getDcs();
-
-			//update fed ranges
-			for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) {
-				int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]);
-
-				diagFedMap.getFederatedRanges()[i].setBeginDim(0,
-					(diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
-						i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
-				diagFedMap.getFederatedRanges()[i].setEndDim(0,
-					diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
-				diagFedMap.getFederatedRanges()[i].setBeginDim(1,
-					(diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
-						i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
-				diagFedMap.getFederatedRanges()[i].setEndDim(1,
-					diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
-			}
+			FederationMap diagFedMap = updateFedRanges(result);
 
 			//update output mapping and data characteristics
 			MatrixObject rdiag = ec.getMatrixObject(output);
@@ -185,10 +164,44 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
 				.set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1),
 					(int) mo1.getBlocksize());
 			rdiag.setFedMapping(diagFedMap);
-			if ( _fedOut != null && _fedOut.isForcedLocal() ){
-				rdiag.acquireReadAndRelease();
-				rdiag.getFedMapping().cleanup(getTID(), rdiag.getFedMapping().getID());
-			}
+			optionalForceLocal(rdiag);
+		}
+	}
+
+	/**
+	 * Update the federated ranges of result and return the updated federation map.
+	 * @param result RdiagResult for which the fedmap is updated
+	 * @return updated federation map
+	 */
+	private FederationMap updateFedRanges(RdiagResult result){
+		FederationMap diagFedMap = result.getFedMap();
+		Map<FederatedRange, int[]> dcs = result.getDcs();
+
+		for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) {
+			int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]);
+
+			diagFedMap.getFederatedRanges()[i].setBeginDim(0,
+				(diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+					i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
+			diagFedMap.getFederatedRanges()[i].setEndDim(0,
+				diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+			diagFedMap.getFederatedRanges()[i].setBeginDim(1,
+				(diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+					i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
+			diagFedMap.getFederatedRanges()[i].setEndDim(1,
+				diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+		}
+		return diagFedMap;
+	}
+
+	/**
+	 * If federated output is forced local, the output will be retrieved and removed from federated workers.
+	 * @param outputMatrixObject which will be retrieved and removed from federated workers
+	 */
+	private void optionalForceLocal(MatrixObject outputMatrixObject){
+		if ( _fedOut != null && _fedOut.isForcedLocal() ){
+			outputMatrixObject.acquireReadAndRelease();
+			outputMatrixObject.getFedMapping().cleanup(getTID(), outputMatrixObject.getFedMapping().getID());
 		}
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index 0ae3178..dea0acf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -110,4 +110,18 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
 		out.split(parts[parts.length - 2]);
 		return opcode;
 	}
+
+	/**
+	 * Parse and return federated output flag from given instr string at given position.
+	 * If the position given is greater than the length of the instruction, FederatedOutput.NONE is returned.
+	 * @param instr instruction string to be parsed
+	 * @param position of federated output flag
+	 * @return parsed federated output flag or FederatedOutput.NONE
+	 */
+	static FederatedOutput parseFedOutFlag(String instr, int position){
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
+		if ( parts.length > position )
+			return FederatedOutput.valueOf(parts[position]);
+		else return FederatedOutput.NONE;
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
index cda9966..e4e7a88 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
@@ -24,6 +24,7 @@ import java.util.Collection;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -69,7 +70,21 @@ public class FederatedRdiagTest extends AutomatedTestBase {
 	@Test
 	public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); }
 
+	@Test
+	public void federatedCompilationRDiagCP(){
+		federatedRdiag(Types.ExecMode.SINGLE_NODE, true);
+	}
+
+	@Test
+	public void federatedCompilationRdiagSP(){
+		federatedRdiag(Types.ExecMode.SPARK, true);
+	}
+
 	public void federatedRdiag(Types.ExecMode execMode) {
+		federatedRdiag(execMode, false);
+	}
+
+	public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilation) {
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		Types.ExecMode platformOld = rtplatform;
 
@@ -111,6 +126,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
 			input("X1"), input("X2"), input("X3"), input("X4"), expected("S")};
 		runTest(null);
 
+		OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
 		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
 		loadTestConfiguration(config);
 
@@ -139,5 +155,6 @@ public class FederatedRdiagTest extends AutomatedTestBase {
 		TestUtils.shutdownThreads(t1, t2, t3, t4);
 		rtplatform = platformOld;
 		DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		OptimizerUtils.FEDERATED_COMPILATION = false;
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
index 847f351..66f9c2f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
@@ -23,7 +23,9 @@ import java.util.Arrays;
 import java.util.Collection;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -77,7 +79,21 @@ public class FederatedRevTest extends AutomatedTestBase {
 		runRevTest(ExecMode.SPARK);
 	}
 
+	@Test
+	public void federatedCompilationRevCP(){
+		runRevTest(Types.ExecMode.SINGLE_NODE, true);
+	}
+
+	@Test
+	public void federatedCompilationRevSP(){
+		runRevTest(Types.ExecMode.SPARK, true);
+	}
+
 	private void runRevTest(ExecMode execMode) {
+		runRevTest(execMode, false);
+	}
+
+	private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		ExecMode platformOld = rtplatform;
 
@@ -135,6 +151,7 @@ public class FederatedRevTest extends AutomatedTestBase {
 
 		runTest(null);
 
+		OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
 		fullDMLScriptName = HOME + TEST_NAME + ".dml";
 		programArgs = new String[] {"-stats", "100", "-nvargs",
 			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
@@ -160,6 +177,7 @@ public class FederatedRevTest extends AutomatedTestBase {
 
 		rtplatform = platformOld;
 		DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		OptimizerUtils.FEDERATED_COMPILATION = false;
 
 	}
 }