You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/10/19 15:07:37 UTC
[systemds] branch master updated: [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 2026cff [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
2026cff is described below
commit 2026cfff97fd992c75b6b56ac01d8d199f4f9db3
Author: sebwrede <sw...@know-center.at>
AuthorDate: Mon Oct 11 18:30:28 2021 +0200
[SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
This commit ensures that the Reorg operations rdiag and rev are compiled with the federated output flag (FOUT/LOUT).
Additionally, it removes rshape from the FEDInstructionParser and documents rshape and rsort there as unsupported,
since federated parsing of these Reorg types is not supported yet by ReorgFEDInstruction.
Closes #1414.
---
.../hops/rewrite/IPAPassRewriteFederatedPlan.java | 1 -
src/main/java/org/apache/sysds/lops/Lop.java | 6 +-
.../runtime/instructions/FEDInstructionParser.java | 3 +-
.../runtime/instructions/InstructionUtils.java | 25 ++++++-
.../instructions/fed/ReorgFEDInstruction.java | 81 +++++++++++++---------
.../instructions/fed/UnaryFEDInstruction.java | 14 ++++
.../federated/primitives/FederatedRdiagTest.java | 17 +++++
.../federated/primitives/FederatedRevTest.java | 18 +++++
8 files changed, 125 insertions(+), 40 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
index cbc21cf..377ebb1 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -252,7 +252,6 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
if ( hopRels.isEmpty() )
hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
hopRelMemo.put(currentHop.getHopID(), hopRels);
- currentHop.setVisited();
}
/**
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index e014d3c..7da091f 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -117,9 +117,9 @@ public abstract class Lop
protected PrivacyConstraint privacyConstraint;
/**
- * Boolean defining if the output of the operation should be federated.
- * If it is true, the output should be kept at federated sites.
- * If it is false, the output should be retrieved by the coordinator.
+ * Enum defining if the output of the operation should be forced federated, forced local or neither.
+ * If it is FOUT, the output should be kept at federated sites.
+ * If it is LOUT, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _fedOutput = null;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 755287a..8000da7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -63,8 +63,9 @@ public class FEDInstructionParser extends InstructionParser
// Reorg Instruction Opcodes (repositioning of existing values)
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
- String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
String2FEDInstructionType.put( "rev" , FEDType.Reorg );
+ //String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
+ //String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
// Ternary Instruction Opcodes
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 246ed87..9991edf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -85,6 +85,7 @@ import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -1144,8 +1145,30 @@ public class InstructionUtils
return linst;
}
+ /**
+ * Removes the federated output flag if the last operand (from the final Lop.OPERAND_DELIMITOR on) contains one.
+ * @param linst instruction string; NOTE(review): must contain Lop.OPERAND_DELIMITOR, otherwise substring(-1) throws
+ * @return instruction string without the trailing federated output flag, or linst unchanged if no flag is present
+ */
public static String removeFEDOutputFlag(String linst){
- return linst.substring(0, linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
+ int lastOperandStartIndex = linst.lastIndexOf(Lop.OPERAND_DELIMITOR);
+ String lastOperand = linst.substring(lastOperandStartIndex);
+ if ( containsFEDOutputFlag(lastOperand) )
+ return linst.substring(0, lastOperandStartIndex);
+ else return linst;
+ }
+
+ /**
+ * Checks whether the given operand string contains the name of any FederatedOutput value (substring match).
+ * @param operandString operand which is checked for a federated output flag
+ * @return true if the name of any FederatedOutput value occurs anywhere in the given operand string
+ */
+ private static boolean containsFEDOutputFlag(String operandString){
+ for (FederatedOutput fedOutput : FederatedOutput.values()){
+ if ( operandString.contains(fedOutput.name()) )
+ return true;
+ }
+ return false;
}
private static String replaceOperand(String linst, CPOperand oldOperand, String newOperandName){
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index c32b15b..4202498 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -54,8 +54,6 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
- @SuppressWarnings("unused")
- private static boolean fedoutFlagInString = false;
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -71,23 +69,25 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
+ FederatedOutput fedOut;
if ( opcode.equalsIgnoreCase("r'") ) {
InstructionUtils.checkNumFields(str, 2, 3, 4);
in.split(parts[1]);
out.split(parts[2]);
int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0 : Integer.parseInt(parts[3]);
- FederatedOutput fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ? FederatedOutput.valueOf(parts[3]) :
- FederatedOutput.valueOf(parts[4]);
+ fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
+ FederatedOutput.valueOf(parts[3]) : FederatedOutput.valueOf(parts[4]);
return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, fedOut);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
- return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
+ fedOut = parseFedOutFlag(str, 3);
+ return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str, fedOut);
}
else if ( opcode.equalsIgnoreCase("rev") ) {
- fedoutFlagInString = parts.length > 3;
parseUnaryInstruction(str, in, out); //max 2 operands
- return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+ fedOut = parseFedOutFlag(str, 3);
+ return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut);
}
else {
throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode);
@@ -117,7 +117,6 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
mo1.getFedMapping().execute(getTID(), true, fr, fr1);
if (_fedOut != null && !_fedOut.isForcedLocal()){
- mo1.getFedMapping().execute(getTID(), true, fr1);
//drive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
@@ -146,10 +145,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
- if ( _fedOut != null && _fedOut.isForcedLocal() ){
- out.acquireReadAndRelease();
- out.getFedMapping().cleanup(getTID(), fr1.getID());
- }
+ optionalForceLocal(out);
}
else if (instOpcode.equals("rdiag")) {
RdiagResult result;
@@ -160,24 +156,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
result = rdiagM2V(mo1, r_op);
}
- FederationMap diagFedMap = result.getFedMap();
- Map<FederatedRange, int[]> dcs = result.getDcs();
-
- //update fed ranges
- for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) {
- int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]);
-
- diagFedMap.getFederatedRanges()[i].setBeginDim(0,
- (diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
- i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
- diagFedMap.getFederatedRanges()[i].setEndDim(0,
- diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
- diagFedMap.getFederatedRanges()[i].setBeginDim(1,
- (diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
- i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
- diagFedMap.getFederatedRanges()[i].setEndDim(1,
- diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
- }
+ FederationMap diagFedMap = updateFedRanges(result);
//update output mapping and data characteristics
MatrixObject rdiag = ec.getMatrixObject(output);
@@ -185,10 +164,44 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
.set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1),
(int) mo1.getBlocksize());
rdiag.setFedMapping(diagFedMap);
- if ( _fedOut != null && _fedOut.isForcedLocal() ){
- rdiag.acquireReadAndRelease();
- rdiag.getFedMapping().cleanup(getTID(), rdiag.getFedMapping().getID());
- }
+ optionalForceLocal(rdiag);
+ }
+ }
+
+ /**
+ * Recomputes begin/end dims of each federated range in result's fed map from the per-range sizes in result.getDcs().
+ * @param result RdiagResult whose federation map ranges are modified in place
+ * @return the federation map of result (same object), with updated ranges
+ */
+ private FederationMap updateFedRanges(RdiagResult result){
+ FederationMap diagFedMap = result.getFedMap();
+ Map<FederatedRange, int[]> dcs = result.getDcs();
+
+ for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) {
+ int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]);
+
+ diagFedMap.getFederatedRanges()[i].setBeginDim(0,
+ (diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+ i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
+ diagFedMap.getFederatedRanges()[i].setEndDim(0,
+ diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+ diagFedMap.getFederatedRanges()[i].setBeginDim(1,
+ (diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+ i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
+ diagFedMap.getFederatedRanges()[i].setEndDim(1,
+ diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+ }
+ return diagFedMap;
+ }
+
+ /**
+ * If _fedOut is forced local, reads the output to the coordinator and then cleans it up on the federated workers.
+ * @param outputMatrixObject output which is acquired locally and whose federated data is cleaned up
+ */
+ private void optionalForceLocal(MatrixObject outputMatrixObject){
+ if ( _fedOut != null && _fedOut.isForcedLocal() ){
+ outputMatrixObject.acquireReadAndRelease();
+ outputMatrixObject.getFedMapping().cleanup(getTID(), outputMatrixObject.getFedMapping().getID());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index 0ae3178..dea0acf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -110,4 +110,18 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
out.split(parts[parts.length - 2]);
return opcode;
}
+
+ /**
+ * Parse and return the federated output flag at the given position of the instruction parts.
+ * If the instruction has no part at that position (position >= number of parts), FederatedOutput.NONE is returned.
+ * @param instr instruction string to be parsed
+ * @param position index of the federated output flag in the instruction parts (index 0 is the opcode)
+ * @return parsed flag or FederatedOutput.NONE; FederatedOutput.valueOf throws if the part is not a flag name
+ */
+ static FederatedOutput parseFedOutFlag(String instr, int position){
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
+ if ( parts.length > position )
+ return FederatedOutput.valueOf(parts[position]);
+ else return FederatedOutput.NONE;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
index cda9966..e4e7a88 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
@@ -24,6 +24,7 @@ import java.util.Collection;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
@@ -69,7 +70,21 @@ public class FederatedRdiagTest extends AutomatedTestBase {
@Test
public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); }
+ @Test
+ public void federatedCompilationRDiagCP(){
+ federatedRdiag(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ @Test
+ public void federatedCompilationRdiagSP(){
+ federatedRdiag(Types.ExecMode.SPARK, true);
+ }
+
public void federatedRdiag(Types.ExecMode execMode) {
+ federatedRdiag(execMode, false);
+ }
+
+ public void federatedRdiag(Types.ExecMode execMode, boolean activateFedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -111,6 +126,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
input("X1"), input("X2"), input("X3"), input("X4"), expected("S")};
runTest(null);
+ OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -139,5 +155,6 @@ public class FederatedRdiagTest extends AutomatedTestBase {
TestUtils.shutdownThreads(t1, t2, t3, t4);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
index 847f351..66f9c2f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
@@ -23,7 +23,9 @@ import java.util.Arrays;
import java.util.Collection;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
@@ -77,7 +79,21 @@ public class FederatedRevTest extends AutomatedTestBase {
runRevTest(ExecMode.SPARK);
}
+ @Test
+ public void federatedCompilationRevCP(){
+ runRevTest(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ @Test
+ public void federatedCompilationRevSP(){
+ runRevTest(Types.ExecMode.SPARK, true);
+ }
+
private void runRevTest(ExecMode execMode) {
+ runRevTest(execMode, false);
+ }
+
+ private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -135,6 +151,7 @@ public class FederatedRevTest extends AutomatedTestBase {
runTest(null);
+ OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
@@ -160,6 +177,7 @@ public class FederatedRevTest extends AutomatedTestBase {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
}
}