You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/14 11:20:45 UTC

[systemds] 01/02: [SYSTEMDS-3018] Federated parameter server tests: print output only on failure

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 2c2fedc0b26f53597679e9be2b77e523167444ce
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Tue Sep 14 13:15:41 2021 +0200

    [SYSTEMDS-3018] Federated parameter server tests: print output only on failure
---
 .../federated/paramserv/AvgModelFederatedParamservTest.java    | 10 ++++------
 .../functions/federated/paramserv/FederatedParamservTest.java  |  8 +++-----
 .../federated/paramserv/NbatchesFederatedParamservTest.java    | 10 ++++------
 .../functions/federated/primitives/FederatedRCBindTest.java    |  1 -
 .../primitives/FederatedWeightedUnaryMatrixMultTest.java       |  2 +-
 5 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index 66482f3..702f632 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class AvgModelFederatedParamservTest extends AutomatedTestBase {
-	private static final Log LOG = LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
+	// private static final Log LOG = LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
 	private final static String TEST_DIR = "functions/federated/paramserv/";
 	private final static String TEST_NAME = "AvgModelFederatedParamservTest";
 	private final static String TEST_CLASS_DIR = TEST_DIR + AvgModelFederatedParamservTest.class.getSimpleName() + "/";
@@ -131,7 +129,7 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
 		// config
 		getAndLoadTestConfiguration(TEST_NAME);
 		String HOME = SCRIPT_DIR + TEST_DIR;
-		setOutputBuffering(false);
+		setOutputBuffering(true);
 
 		int C = 1, Hin = 28, Win = 28;
 		int numLabels = 10;
@@ -201,8 +199,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
 				"modelAvg=" +  Boolean.toString(modelAvg).toUpperCase()));
 
 			programArgs = programArgsList.toArray(new String[0]);
-			LOG.debug(runTest(null));
-			Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+			String log = runTest(null).toString();
+			Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
 
 			// shut down threads
 			for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index c316214..fd40275 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedParamservTest extends AutomatedTestBase {
-	private static final Log LOG = LogFactory.getLog(FederatedParamservTest.class.getName());
+	// private static final Log LOG = LogFactory.getLog(FederatedParamservTest.class.getName());
 	private final static String TEST_DIR = "functions/federated/paramserv/";
 	private final static String TEST_NAME = "FederatedParamservTest";
 	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
@@ -199,8 +197,8 @@ public class FederatedParamservTest extends AutomatedTestBase {
 					"seed=" + _seed));
 
 			programArgs = programArgsList.toArray(new String[0]);
-			LOG.debug(runTest(null));
-			Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+			String log = runTest(null).toString();
+			Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
 			
 			// shut down threads
 			for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
index e2e4f20..9b9f9bb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class NbatchesFederatedParamservTest extends AutomatedTestBase {
-	private static final Log LOG = LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
+	// private static final Log LOG = LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
 	private final static String TEST_DIR = "functions/federated/paramserv/";
 	private final static String TEST_NAME = "NbatchesFederatedParamservTest";
 	private final static String TEST_CLASS_DIR = TEST_DIR + NbatchesFederatedParamservTest.class.getSimpleName() + "/";
@@ -111,7 +109,7 @@ public class NbatchesFederatedParamservTest extends AutomatedTestBase {
 		// config
 		getAndLoadTestConfiguration(TEST_NAME);
 		String HOME = SCRIPT_DIR + TEST_DIR;
-		setOutputBuffering(false);
+		setOutputBuffering(true);
 
 		int C = 1, Hin = 28, Win = 28;
 		int numLabels = 10;
@@ -181,8 +179,8 @@ public class NbatchesFederatedParamservTest extends AutomatedTestBase {
 				"nbatches=" + _nbatches));
 
 			programArgs = programArgsList.toArray(new String[0]);
-			LOG.debug(runTest(null));
-			Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+			String log = runTest(null).toString();
+			Assert.assertEquals("Test Failed \n" + log,0, Statistics.getNoOfExecutedSPInst());
 
 			// shut down threads
 			for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index 1470274..04b668d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
index 8bc9fee..9086edf 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
@@ -182,7 +182,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends AutomatedTestBase
 			runTest(true, false, null, -1);
 
 			// compare the results via files
-			HashMap<CellIndex, Double> refResults	= readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+			HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
 			HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
 			TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");