You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/04 21:46:29 UTC
[systemds] branch main updated: [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 70c3e5f93d [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
70c3e5f93d is described below
commit 70c3e5f93d4ef22447d765ef261985129ff1a7e2
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Jun 4 23:45:20 2022 +0200
[MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
---
.../privacy/FederatedWorkerHandlerTest.java | 119 +++++++++------------
1 file changed, 51 insertions(+), 68 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index 7339aea931..d23fe8c533 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -32,9 +33,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Ignore;
import org.junit.Test;
-import static java.lang.Thread.sleep;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
@net.jcip.annotations.NotThreadSafe
public class FederatedWorkerHandlerTest extends AutomatedTestBase {
@@ -49,7 +48,6 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
private final static String TRANSFER_TEST_NAME = "FederatedRCBindTest";
private final static String MATVECMULT_TEST_NAME = "FederatedMultiplyTest";
private static final String FEDERATED_WORKER_HOST = "localhost";
- private static final int FEDERATED_WORKER_PORT = 1222;
private final static int blocksize = 1024;
private final int rows = 10;
@@ -103,20 +101,15 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
private void runGenericScalarTest(String dmlFile, int s, Class<?> expectedException, PrivacyLevel privacyLevel)
{
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- Types.ExecMode platformOld = rtplatform;
+ ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
- Thread t = null;
try {
- // we need the reference file to not be written to hdfs, so we get the correct format
- rtplatform = Types.ExecMode.SINGLE_NODE;
- programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)};
- t = new Thread(() -> runTest(true, false, null, -1));
- t.start();
- sleep(FED_WORKER_WAIT);
+ int port = getRandomAvailablePort();
+ Thread t = startLocalFedWorkerThread(port);
+
fullDMLScriptName = SCRIPT_DIR + TEST_DIR_SCALAR + dmlFile + ".dml";
programArgs = new String[]{"-checkPrivacy", "-nvargs",
- "in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, FEDERATED_WORKER_PORT, input("M")),
+ "in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, port, input("M")),
"rows=" + Integer.toString(rows), "cols=" + Integer.toString(cols),
"scalar=" + Integer.toString(s),
"out=" + output("R")};
@@ -125,15 +118,12 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
if ( !exceptionExpected )
compareResults();
- } catch (InterruptedException e) {
- fail("InterruptedException thrown" + e.getMessage() + " " + Arrays.toString(e.getStackTrace()));
- } finally {
+ TestUtils.shutdownThread(t);
+ }
+ finally {
assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
checkedPrivacyConstraintsContains(privacyLevel));
- rtplatform = platformOld;
- TestUtils.shutdownThread(t);
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
@@ -153,57 +143,50 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
}
public void federatedSum(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- Types.ExecMode platformOld = rtplatform;
-
-
- getAndLoadTestConfiguration("aggregation");
- String HOME = SCRIPT_DIR + TEST_DIR_fed;
+ ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
- double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
- writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
- int port = getRandomAvailablePort();
- Thread t = startLocalFedWorkerThread(port);
-
- // we need the reference file to not be written to hdfs, so we get the correct format
- rtplatform = Types.ExecMode.SINGLE_NODE;
- // Run reference dml script with normal matrix for Row/Col sum
- fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
- runTest(true, false, null, -1);
-
- // write expected sum
- double sum = 0;
- for(double[] doubles : A) {
- sum += Arrays.stream(doubles).sum();
+ try {
+ getAndLoadTestConfiguration("aggregation");
+ String HOME = SCRIPT_DIR + TEST_DIR_fed;
+
+ double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
+ writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
+ int port = getRandomAvailablePort();
+ Thread t = startLocalFedWorkerThread(port);
+
+ // Run reference dml script with normal matrix for Row/Col sum
+ fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+ runTest(true, false, null, -1);
+
+ // write expected sum
+ double sum = 0;
+ for(double[] doubles : A) {
+ sum += Arrays.stream(doubles).sum();
+ }
+
+ if ( expectedException == null )
+ writeExpectedScalar("S", sum);
+
+ TestConfiguration config = availableTestConfigurations.get("aggregation");
+ loadTestConfiguration(config);
+ fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
+ programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+ "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
+
+ runTest(true, (expectedException != null), expectedException, -1);
+
+ // compare all sums via files
+ if ( expectedException == null )
+ compareResults(1e-11);
+
+ assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
+ checkedPrivacyConstraintsContains(privacyLevel));
+ TestUtils.shutdownThread(t);
}
-
- if ( expectedException == null )
- writeExpectedScalar("S", sum);
-
- // reference file should not be written to hdfs, so we set platform here
- rtplatform = execMode;
- if(rtplatform == Types.ExecMode.SPARK) {
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ finally {
+ resetExecMode(platformOld);
}
- TestConfiguration config = availableTestConfigurations.get("aggregation");
- loadTestConfiguration(config);
- fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
- programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
- "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
-
- runTest(true, (expectedException != null), expectedException, -1);
-
- // compare all sums via files
- if ( expectedException == null )
- compareResults(1e-11);
-
- assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
- checkedPrivacyConstraintsContains(privacyLevel));
-
- TestUtils.shutdownThread(t);
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
@Test