You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ki...@apache.org on 2022/07/12 16:54:09 UTC

[systemds] 01/01: [MINOR] Fix Spark ParameterServer

This is an automated email from the ASF dual-hosted git repository.

kinnerebner pushed a commit to branch paramserv_spark_fix
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit f586eaa8b95aefc7c67eea379b69405463632447
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Mon Jul 11 22:47:01 2022 +0200

    [MINOR] Fix Spark ParameterServer
    
    This patch fixes the Spark execution mode for the parameter server. Commit 28ff18fca2a9258168db7397d56236a5e0d9564b changed how functions are handled, which caused the parameter server in Spark mode to no longer find the functions or send them to the workers properly.
    
    Closes #1662
---
 .../runtime/controlprogram/paramserv/ParamServer.java  | 18 ++++++++++--------
 .../controlprogram/paramserv/ParamservUtils.java       |  3 +++
 .../controlprogram/paramserv/SparkPSWorker.java        |  3 +++
 .../instructions/cp/ParamservBuiltinCPInstruction.java |  4 ++--
 .../test/functions/paramserv/ParamservSparkNNTest.java |  5 ++++-
 5 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 3957965988..e88a19d964 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -78,7 +78,8 @@ public abstract class ParamServer
 
 	private int _numWorkers;
 	private int _numBackupWorkers;
-	private boolean[] _discardWorkerRes;
+	// number of updates the respective worker is straggling behind
+	private int[] _numUpdatesStraggling;
 	private boolean _modelAvg;
 	private ListObject _accModels = null;
 
@@ -109,7 +110,7 @@ public abstract class ParamServer
 		_numBatchesPerEpoch = numBatchesPerEpoch;
 		_numWorkers = workerNum;
 		_numBackupWorkers = numBackupWorkers;
-		_discardWorkerRes = new boolean[workerNum];
+		_numUpdatesStraggling = new int[workerNum];
 		_modelAvg = modelAvg;
 
 		// broadcast initial model
@@ -118,6 +119,8 @@ public abstract class ParamServer
 
 	protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
 		String[] cfn = DMLProgram.splitFunctionKey(aggFunc);
+		if(cfn.length == 1)
+			cfn = new String[] {null, cfn[0]};
 		String ns = cfn[0];
 		String fname = cfn[1];
 		boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
@@ -240,10 +243,10 @@ public abstract class ParamServer
 					break;
 				}
 				case SBP: {
-					if(_discardWorkerRes[workerID]) {
+					if(_numUpdatesStraggling[workerID] > 0) {
 						LOG.info("[+] PRAMSERV: discarding result of backup-worker/straggler " + workerID);
 						broadcastModel(workerID);
-						_discardWorkerRes[workerID] = false;
+						_numUpdatesStraggling[workerID]--;
 						break;
 					}
 					setFinishedState(workerID);
@@ -255,7 +258,6 @@ public abstract class ParamServer
 						updateGlobalModel(gradients);
 
 					if(enoughFinished()) {
-						// set flags to throwaway backup worker results
 						tagStragglers();
 						performGlobalGradientUpdate();
 					}
@@ -300,7 +302,7 @@ public abstract class ParamServer
 	private void tagStragglers() {
 		for(int i = 0; i < _finishedStates.length; ++i) {
 			if(!_finishedStates[i])
-				_discardWorkerRes[i] = true;
+				_numUpdatesStraggling[i]++;
 		}
 	}
 
@@ -371,10 +373,10 @@ public abstract class ParamServer
 				case SBP: {
 					// first weight the models based on number of workers
 					ListObject weightParams = weightModels(model, _numWorkers - _numBackupWorkers);
-					if(_discardWorkerRes[workerID]) {
+					if(_numUpdatesStraggling[workerID] > 0) {
 						LOG.info("[+] PRAMSERV: discarding result of backup-worker/straggler " + workerID);
 						broadcastModel(workerID);
-						_discardWorkerRes[workerID] = false;
+						_numUpdatesStraggling[workerID]--;
 						break;
 					}
 					setFinishedState(workerID);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index cfc3a200a5..2a6877d89e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -268,7 +268,10 @@ public class ParamservUtils {
 			String[] parts = DMLProgram.splitFunctionKey(e.getKey());
 			FunctionProgramBlock fpb = ProgramConverter
 				.createDeepCopyFunctionProgramBlock(e.getValue(), new HashSet<>(), new HashSet<>());
+			fpb._namespace = parts[0];
+			fpb._functionName = parts[1];
 			newProg.addFunctionProgramBlock(parts[0], parts[1], fpb, opt);
+			newProg.addProgramBlock(fpb);
 		}
 		return newProg;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
index 9e96b45a5b..7823d8811c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -76,6 +76,9 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		_nEpochs = aEpochs;
 		_nbatches = nbatches;
 		_modelAvg = modelAvg;
+		
+		// make SparkPSWorker serializable
+		_tpool = null;
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 1fa83b2a8d..ef45a9c2b3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -661,10 +661,10 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 	private int getNumBackupWorkers() {
 		if(!getParameterMap().containsKey(PS_NUM_BACKUP_WORKERS)) {
-			if (!getUpdateType().isSBP())
-				LOG.warn("Specifying number of backup-workers without SBP mode has no effect");
 			return DEFAULT_NUM_BACKUP_WORKERS;
 		}
+		if (!getUpdateType().isSBP())
+			LOG.warn("Specifying number of backup-workers without SBP mode has no effect");
 		return Integer.parseInt(getParam(PS_NUM_BACKUP_WORKERS));
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
index c7f0e39dff..630c3c1ebd 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
@@ -29,7 +29,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 
 @net.jcip.annotations.NotThreadSafe
-@Ignore
 public class ParamservSparkNNTest extends AutomatedTestBase {
 
 	private static final String TEST_NAME1 = "paramserv-test";
@@ -77,12 +76,16 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 	}
 
 	@Test
+	@Ignore
 	public void testParamservWorkerFailed() {
+		// FIXME: `aggregation` function can't be found (optimized away?)
 		runDMLTest(TEST_NAME2, true, DMLRuntimeException.class, "Invalid indexing by name in unnamed list: worker_err.");
 	}
 
 	@Test
+	@Ignore
 	public void testParamservAggServiceFailed() {
+		// FIXME: `aggregation` function can't be found (optimized away?)
 		runDMLTest(TEST_NAME3, true, DMLRuntimeException.class, "Invalid indexing by name in unnamed list: agg_service_err.");
 	}