You are viewing a plain-text version of this content. The canonical link to the original ("here") was a hyperlink and is not preserved in this text.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/14 11:20:45 UTC
[systemds] 01/02: [SYSTEMDS-3018] Federated parameterserver print only if failing
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 2c2fedc0b26f53597679e9be2b77e523167444ce
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Tue Sep 14 13:15:41 2021 +0200
[SYSTEMDS-3018] Federated parameterserver print only if failing
---
.../federated/paramserv/AvgModelFederatedParamservTest.java | 10 ++++------
.../functions/federated/paramserv/FederatedParamservTest.java | 8 +++-----
.../federated/paramserv/NbatchesFederatedParamservTest.java | 10 ++++------
.../functions/federated/primitives/FederatedRCBindTest.java | 1 -
.../primitives/FederatedWeightedUnaryMatrixMultTest.java | 2 +-
5 files changed, 12 insertions(+), 19 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index 66482f3..702f632 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.List;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class AvgModelFederatedParamservTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
+ // private static final Log LOG = LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
private final static String TEST_DIR = "functions/federated/paramserv/";
private final static String TEST_NAME = "AvgModelFederatedParamservTest";
private final static String TEST_CLASS_DIR = TEST_DIR + AvgModelFederatedParamservTest.class.getSimpleName() + "/";
@@ -131,7 +129,7 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
// config
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- setOutputBuffering(false);
+ setOutputBuffering(true);
int C = 1, Hin = 28, Win = 28;
int numLabels = 10;
@@ -201,8 +199,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
"modelAvg=" + Boolean.toString(modelAvg).toUpperCase()));
programArgs = programArgsList.toArray(new String[0]);
- LOG.debug(runTest(null));
- Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+ String log = runTest(null).toString();
+ Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
// shut down threads
for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index c316214..fd40275 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.List;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedParamservTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(FederatedParamservTest.class.getName());
+ // private static final Log LOG = LogFactory.getLog(FederatedParamservTest.class.getName());
private final static String TEST_DIR = "functions/federated/paramserv/";
private final static String TEST_NAME = "FederatedParamservTest";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
@@ -199,8 +197,8 @@ public class FederatedParamservTest extends AutomatedTestBase {
"seed=" + _seed));
programArgs = programArgsList.toArray(new String[0]);
- LOG.debug(runTest(null));
- Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+ String log = runTest(null).toString();
+ Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
// shut down threads
for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
index e2e4f20..9b9f9bb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.List;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class NbatchesFederatedParamservTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
+ // private static final Log LOG = LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
private final static String TEST_DIR = "functions/federated/paramserv/";
private final static String TEST_NAME = "NbatchesFederatedParamservTest";
private final static String TEST_CLASS_DIR = TEST_DIR + NbatchesFederatedParamservTest.class.getSimpleName() + "/";
@@ -111,7 +109,7 @@ public class NbatchesFederatedParamservTest extends AutomatedTestBase {
// config
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- setOutputBuffering(false);
+ setOutputBuffering(true);
int C = 1, Hin = 28, Win = 28;
int numLabels = 10;
@@ -181,8 +179,8 @@ public class NbatchesFederatedParamservTest extends AutomatedTestBase {
"nbatches=" + _nbatches));
programArgs = programArgsList.toArray(new String[0]);
- LOG.debug(runTest(null));
- Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+ String log = runTest(null).toString();
+ Assert.assertEquals("Test Failed \n" + log,0, Statistics.getNoOfExecutedSPInst());
// shut down threads
for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index 1470274..04b668d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
index 8bc9fee..9086edf 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
@@ -182,7 +182,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends AutomatedTestBase
runTest(true, false, null, -1);
// compare the results via files
- HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+ HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");