You are viewing a plain-text version of this content; the link to its canonical version was lost in this plain-text rendering.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/13 11:42:20 UTC

[systemds] branch master updated: [SYSTEMDS-3018] Fix federated paramserv setup of model update functions

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e112345  [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
e112345 is described below

commit e112345edaebced1f419c2e4cd7abad08dba6599
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Mon Sep 13 13:42:05 2021 +0200

    [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
    
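    This patch fixes inconsistencies in federated paramserv with model
    averaging: the aggregation (model update) function is now also set up
    for local updates with model averaging, even when only a single batch
    is computed, and the previously ignored single-node and hybrid
    AvgModelFederatedParamservTest cases are re-enabled.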
---
 .../controlprogram/paramserv/FederatedPSControlThread.java        | 2 +-
 .../federated/paramserv/AvgModelFederatedParamservTest.java       | 8 ++------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 85bb745..ea8f0e8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -503,7 +503,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			// recreate aggregation instruction and output if needed
 			Instruction aggregationInstruction = null;
 			DataIdentifier aggregationOutput = null;
-			if(_localUpdate && _numBatchesToCompute > 1) {
+			if(_localUpdate && _numBatchesToCompute > 1 | modelAvg) {
 				func = ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
 				inputs = func.getInputParams();
 				outputs = func.getOutputParams();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index d96097e..66482f3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -118,15 +117,11 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
 	}
 
 	@Test
-	@Ignore
-	// TODO FIX ME
 	public void AvgmodelfederatedParamservSingleNode() {
 		AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
 	}
 
 	@Test
-	@Ignore
-	// TODO FIX ME
 	public void AvgmodelfederatedParamservHybrid() {
 		AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
 	}
@@ -149,7 +144,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
 			List<Thread> threads = new ArrayList<>();
 			for(int i = 0; i < _numFederatedWorkers; i++) {
 				ports.add(getRandomAvailablePort());
-				threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+				threads.add(startLocalFedWorkerThread(ports.get(i),
+					i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S));
 			}
 
 			// generate test data