You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/13 11:42:20 UTC
[systemds] branch master updated: [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e112345 [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
e112345 is described below
commit e112345edaebced1f419c2e4cd7abad08dba6599
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Sep 13 13:42:05 2021 +0200
[SYSTEMDS-3018] Fix federated paramserv setup of model update functions
This patch fixes inconsistencies in federated paramserv with model
averaging.
---
.../controlprogram/paramserv/FederatedPSControlThread.java | 2 +-
.../federated/paramserv/AvgModelFederatedParamservTest.java | 8 ++------
2 files changed, 3 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 85bb745..ea8f0e8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -503,7 +503,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
// recreate aggregation instruction and output if needed
Instruction aggregationInstruction = null;
DataIdentifier aggregationOutput = null;
- if(_localUpdate && _numBatchesToCompute > 1) {
+ if(_localUpdate && _numBatchesToCompute > 1 | modelAvg) {
func = ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
inputs = func.getInputParams();
outputs = func.getOutputParams();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index d96097e..66482f3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -118,15 +117,11 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
}
@Test
- @Ignore
- // TODO FIX ME
public void AvgmodelfederatedParamservSingleNode() {
AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
}
@Test
- @Ignore
- // TODO FIX ME
public void AvgmodelfederatedParamservHybrid() {
AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
}
@@ -149,7 +144,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
List<Thread> threads = new ArrayList<>();
for(int i = 0; i < _numFederatedWorkers; i++) {
ports.add(getRandomAvailablePort());
- threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+ threads.add(startLocalFedWorkerThread(ports.get(i),
+ i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S));
}
// generate test data