You are viewing a plain-text version of this content. The canonical link for it (originally a hyperlink on the word "here") was not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/04 21:46:29 UTC

[systemds] branch main updated: [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 70c3e5f93d [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
70c3e5f93d is described below

commit 70c3e5f93d4ef22447d765ef261985129ff1a7e2
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Jun 4 23:45:20 2022 +0200

    [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
---
 .../privacy/FederatedWorkerHandlerTest.java        | 119 +++++++++------------
 1 file changed, 51 insertions(+), 68 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index 7339aea931..d23fe8c533 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -32,9 +33,7 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Ignore;
 import org.junit.Test;
-import static java.lang.Thread.sleep;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 @net.jcip.annotations.NotThreadSafe
 public class FederatedWorkerHandlerTest extends AutomatedTestBase {
@@ -49,7 +48,6 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 	private final static String TRANSFER_TEST_NAME = "FederatedRCBindTest";
 	private final static String MATVECMULT_TEST_NAME = "FederatedMultiplyTest";
 	private static final String FEDERATED_WORKER_HOST = "localhost";
-	private static final int FEDERATED_WORKER_PORT = 1222;
 
 	private final static int blocksize = 1024;
 	private final int rows = 10;
@@ -103,20 +101,15 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 
 	private void runGenericScalarTest(String dmlFile, int s, Class<?> expectedException, PrivacyLevel privacyLevel)
 	{
-		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		Types.ExecMode platformOld = rtplatform;
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
 
-		Thread t = null;
 		try {
-			// we need the reference file to not be written to hdfs, so we get the correct format
-			rtplatform = Types.ExecMode.SINGLE_NODE;
-			programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)};
-			t = new Thread(() -> runTest(true, false, null, -1));
-			t.start();
-			sleep(FED_WORKER_WAIT);
+			int port = getRandomAvailablePort();
+			Thread t = startLocalFedWorkerThread(port);
+
 			fullDMLScriptName = SCRIPT_DIR + TEST_DIR_SCALAR + dmlFile + ".dml";
 			programArgs = new String[]{"-checkPrivacy", "-nvargs",
-					"in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, FEDERATED_WORKER_PORT, input("M")),
+					"in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, port, input("M")),
 					"rows=" + Integer.toString(rows), "cols=" + Integer.toString(cols),
 					"scalar=" + Integer.toString(s),
 					"out=" + output("R")};
@@ -125,15 +118,12 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 
 			if ( !exceptionExpected )
 				compareResults();
-		} catch (InterruptedException e) {
-			fail("InterruptedException thrown" + e.getMessage() + " " + Arrays.toString(e.getStackTrace()));
-		} finally {
+			TestUtils.shutdownThread(t);
+		}
+		finally {
 			assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
 				checkedPrivacyConstraintsContains(privacyLevel));
-			rtplatform = platformOld;
-			TestUtils.shutdownThread(t);
-			rtplatform = platformOld;
-			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+			resetExecMode(platformOld);
 		}
 	}
 
@@ -153,57 +143,50 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 	}
 
 	public void federatedSum(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
-		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		Types.ExecMode platformOld = rtplatform;
-
-
-		getAndLoadTestConfiguration("aggregation");
-		String HOME = SCRIPT_DIR + TEST_DIR_fed;
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
 
-		double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
-		writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
-		int port = getRandomAvailablePort();
-		Thread t = startLocalFedWorkerThread(port);
-
-		// we need the reference file to not be written to hdfs, so we get the correct format
-		rtplatform = Types.ExecMode.SINGLE_NODE;
-		// Run reference dml script with normal matrix for Row/Col sum
-		fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
-		programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
-		runTest(true, false, null, -1);
-
-		// write expected sum
-		double sum = 0;
-		for(double[] doubles : A) {
-			sum += Arrays.stream(doubles).sum();
+		try {
+			getAndLoadTestConfiguration("aggregation");
+			String HOME = SCRIPT_DIR + TEST_DIR_fed;
+	
+			double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
+			writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
+			int port = getRandomAvailablePort();
+			Thread t = startLocalFedWorkerThread(port);
+	
+			// Run reference dml script with normal matrix for Row/Col sum
+			fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+			runTest(true, false, null, -1);
+	
+			// write expected sum
+			double sum = 0;
+			for(double[] doubles : A) {
+				sum += Arrays.stream(doubles).sum();
+			}
+	
+			if ( expectedException == null )
+				writeExpectedScalar("S", sum);
+	
+			TestConfiguration config = availableTestConfigurations.get("aggregation");
+			loadTestConfiguration(config);
+			fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
+			programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+				"cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
+	
+			runTest(true, (expectedException != null), expectedException, -1);
+	
+			// compare all sums via files
+			if ( expectedException == null )
+				compareResults(1e-11);
+	
+			assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
+				checkedPrivacyConstraintsContains(privacyLevel));
+			TestUtils.shutdownThread(t);
 		}
-
-		if ( expectedException == null )
-			writeExpectedScalar("S", sum);
-
-		// reference file should not be written to hdfs, so we set platform here
-		rtplatform = execMode;
-		if(rtplatform == Types.ExecMode.SPARK) {
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		finally {
+			resetExecMode(platformOld);
 		}
-		TestConfiguration config = availableTestConfigurations.get("aggregation");
-		loadTestConfiguration(config);
-		fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
-		programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
-			"cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
-
-		runTest(true, (expectedException != null), expectedException, -1);
-
-		// compare all sums via files
-		if ( expectedException == null )
-			compareResults(1e-11);
-
-		assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
-			checkedPrivacyConstraintsContains(privacyLevel));
-
-		TestUtils.shutdownThread(t);
-		rtplatform = platformOld;
-		DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 	}
 
 	@Test