You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/01/23 22:20:07 UTC

[systemds] branch main updated: [SYSTEMDS-3278] Fix federated unary aggregate, duplicates in fedinit

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f00c6c7  [SYSTEMDS-3278] Fix federated unary aggregate, duplicates in fedinit
f00c6c7 is described below

commit f00c6c70023826c5f27876edaa59e4b82853ec5b
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Jan 23 23:19:42 2022 +0100

    [SYSTEMDS-3278] Fix federated unary aggregate, duplicates in fedinit
    
    This patch addresses two severe issues that have been introduced or
    identified in recent history:
    
    * The unary aggregate incorrectly issued a cleanup request together
    with the GET_VAR request (for local output), which causes failures if
    a federated instruction is also called on that output (in addition to
    the CP output data).
    
    * The parallel event loop revealed shortcomings with duplicate
    addresses in federated data: since concurrent requests share fixed
    variable names, intermediates are overwritten in this case.
    This patch adds a warning on federated init when duplicate addresses
    are detected, and fixes the incorrect test that relied on them.
---
 .../federated/FederatedLookupTable.java            |  2 +-
 .../fed/AggregateUnaryFEDInstruction.java          | 21 +++---
 .../instructions/fed/InitFEDInstruction.java       | 13 ++++
 .../federated/primitives/FederatedSumTest.java     | 86 +++++++++++-----------
 .../functions/federated/FederatedSumTest.dml       |  2 +-
 .../federated/FederatedSumTestReference.dml        |  2 +-
 6 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index 63defe4..55ab971 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -31,7 +31,7 @@ import org.apache.sysds.api.DMLScript;
  * ExecutionContextMap (ECM) so that every coordinator can address federated
  * variables with its own local sequential variable IDs. Therefore, the IDs
  * among different coordinators do not have to be distinct, as every
- * coordinator works with a seperate ECM at the FederatedWorker.
+ * coordinator works with a separate ECM at the FederatedWorker.
  */
 public class FederatedLookupTable {
 	// the NOHOST constant is needed for creating FederatedLocalData where there
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index d329f44..88a066a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -201,10 +201,9 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
 		FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-		FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
 
 		//execute federated commands and cleanups
-		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aggUOptr, tmp, map));
 		else
@@ -250,18 +249,20 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 			FederatedRequest meanFr1 =  FederationUtils.callInstruction(meanInstr, output, id,
 				new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
 			FederatedRequest meanFr2 = new FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
-			FederatedRequest meanFr3 = map.cleanup(getTID(), meanFr1.getID());
-			meanTmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest, meanFr1, meanFr2, meanFr3} : new FederatedRequest[] {meanFr1, meanFr2, meanFr3});
+			meanTmp = map.execute(getTID(), isSpark ?
+				new FederatedRequest[] {tmpRequest, meanFr1, meanFr2} :
+				new FederatedRequest[] {meanFr1, meanFr2});
 		}
 
 		//create federated commands for aggregation
 		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id,
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
 		FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-		FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
 		
 		//execute federated commands and cleanups
-		Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest,  fr1, fr2, fr3} : new FederatedRequest[] { fr1, fr2, fr3});
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark ?
+			new FederatedRequest[] {tmpRequest, fr1, fr2} :
+			new FederatedRequest[] { fr1, fr2});
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, meanTmp, map));
 		else
@@ -281,7 +282,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
 
-		map.execute(getTID(), fr1, fr2);
+		map.execute(getTID(), true, fr1, fr2);
 		// derive new fed mapping for output
 		MatrixObject out = ec.getMatrixObject(output);
 		out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID()));
@@ -298,7 +299,6 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 			id = fr1.getID();
 		}
 		else {
-
 			if((map.getType() == FederationMap.FType.COL && aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW && aop.isRowAggregate()))
 				fr1 = new FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in.getDataType());
 			else
@@ -307,11 +307,10 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 
 		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
 			new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
-		FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-		FederatedRequest fr4 = map.cleanup(getTID(), fr2.getID());
+		FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
 
 		//execute federated commands and cleanups
-		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3, fr4);
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, map));
 		else
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 3db1c8b..29b2a17 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -26,7 +26,9 @@ import java.net.URL;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -105,6 +107,17 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
 			throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges "
 				+ "(begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
 
+		//check for duplicate addresses (would lead to overwrite with common variable names)
+		// TODO relax requirement by using different execution contexts per federated data?
+		Set<String> addCheck = new HashSet<>();
+		for( Data dat : addresses.getData() )
+			if( dat instanceof StringObject ) {
+				String address = ((StringObject)dat).getStringValue();
+				if(addCheck.contains(address))
+					LOG.warn("Federated data contains address duplicates: " + addresses);
+				addCheck.add(address);
+			}
+		
 		Types.DataType fedDataType;
 		if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
 			fedDataType = Types.DataType.MATRIX;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index 82ac6eb..4f70cce 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -91,49 +91,51 @@ public class FederatedSumTest extends AutomatedTestBase {
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		Types.ExecMode platformOld = rtplatform;
 
-		getAndLoadTestConfiguration(TEST_NAME);
-		String HOME = SCRIPT_DIR + TEST_DIR;
-
-		double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
-		writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
-		int port = getRandomAvailablePort();
-		Thread t = startLocalFedWorkerThread(port);
-
-		// we need the reference file to not be written to hdfs, so we get the correct format
-		rtplatform = Types.ExecMode.SINGLE_NODE;
-		// Run reference dml script with normal matrix for Row/Col sum
-		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-		programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
-		runTest(true, false, null, -1);
-
-		// write expected sum
-		double sum = 0;
-		for(double[] doubles : A) {
-			sum += Arrays.stream(doubles).sum();
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+	
+			double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
+			writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
+			int port = getRandomAvailablePort();
+			Thread t = startLocalFedWorkerThread(port);
+	
+			// we need the reference file to not be written to hdfs, so we get the correct format
+			rtplatform = Types.ExecMode.SINGLE_NODE;
+			// Run reference dml script with normal matrix for Row/Col sum
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+			runTest(true, false, null, -1);
+	
+			// write expected sum
+			double sum = 0;
+			for(double[] doubles : A)
+				sum += Arrays.stream(doubles).sum();
+			writeExpectedScalar("S", sum);
+	
+			// reference file should not be written to hdfs, so we set platform here
+			rtplatform = execMode;
+			if(rtplatform == Types.ExecMode.SPARK) {
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			}
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+			OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-explain","-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+				"cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
+	
+			runTest(true, false, null, -1);
+	
+			// compare all sums via files
+			compareResults(1e-11);
+	
+			TestUtils.shutdownThread(t);
+			rtplatform = platformOld;
 		}
-		sum *= 2;
-		writeExpectedScalar("S", sum);
-
-		// reference file should not be written to hdfs, so we set platform here
-		rtplatform = execMode;
-		if(rtplatform == Types.ExecMode.SPARK) {
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		finally {
+			OptimizerUtils.FEDERATED_COMPILATION = false;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
-		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
-		loadTestConfiguration(config);
-		OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
-		fullDMLScriptName = HOME + TEST_NAME + ".dml";
-		programArgs = new String[] {"-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
-			"cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
-
-		runTest(true, false, null, -1);
-
-		// compare all sums via files
-		compareResults(1e-11);
-
-		TestUtils.shutdownThread(t);
-		rtplatform = platformOld;
-		OptimizerUtils.FEDERATED_COMPILATION = false;
-		DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 	}
 }
diff --git a/src/test/scripts/functions/federated/FederatedSumTest.dml b/src/test/scripts/functions/federated/FederatedSumTest.dml
index 37a19f6..385b1dd 100644
--- a/src/test/scripts/functions/federated/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSumTest.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in, $in), ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+A = federated(addresses=list($in), ranges=list(list(0, 0), list($rows / 2, $cols)))
 s = sum(A)
 r = rowSums(A)
 c = colSums(A)
diff --git a/src/test/scripts/functions/federated/FederatedSumTestReference.dml b/src/test/scripts/functions/federated/FederatedSumTestReference.dml
index af51717..3684860 100644
--- a/src/test/scripts/functions/federated/FederatedSumTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedSumTestReference.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-A = rbind(read($1), read($2))
+A = read($1)
 r = rowSums(A)
 c = colSums(A)
 write(r, $3)