You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ki...@apache.org on 2022/07/08 12:57:51 UTC

[systemds] branch main updated: [SYSTEMDS-3404] Synchronous w/ Backup Worker for ParamServer

This is an automated email from the ASF dual-hosted git repository.

kinnerebner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 003de8596a [SYSTEMDS-3404] Synchronous w/ Backup Worker for ParamServer
003de8596a is described below

commit 003de8596a4cec04122219c49cb11fa22f34b30f
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Tue Jul 5 17:22:48 2022 +0200

    [SYSTEMDS-3404] Synchronous w/ Backup Worker for ParamServer
    
    This adds a new update type to the parameter server implementation. This update type is similar to the synchronous mode, but ignores the results of some stragglers, with the benefit that the other workers can start the next iteration early. It is called SBP, referring to "Synchronous with Backup-workers Parallel".
    
    There are currently still a few limitations:
    1. Stragglers are not stopped early; their results are discarded when they return.
    2. There is no balancing mechanism in place for unbalanced federated workloads. The result computed on the largest federated partition will probably be discarded every time.
    We can consider addressing some of these limitations later or leave them to the end user.
    
    Closes #1653
---
 .../ParameterizedBuiltinFunctionExpression.java    |   2 +-
 .../java/org/apache/sysds/parser/Statement.java    |  10 +-
 .../controlprogram/paramserv/HEParamServer.java    |   9 +-
 .../controlprogram/paramserv/LocalParamServer.java |   9 +-
 .../controlprogram/paramserv/ParamServer.java      | 187 ++++++++++++++++-----
 .../cp/ParamservBuiltinCPInstruction.java          |  57 +++++--
 .../paramserv/AvgModelFederatedParamservTest.java  |   3 +
 .../paramserv/EncryptedFederatedParamservTest.java |   3 +
 .../paramserv/FederatedParamservTest.java          |   3 +
 .../paramserv/NbatchesFederatedParamservTest.java  |   2 +
 .../paramserv/ParamservLocalNNAveragingTest.java   |  25 +++
 .../functions/paramserv/ParamservLocalNNTest.java  |  20 +++
 .../ParamservLocalNNTestwithNbatches.java          |  10 ++
 .../test/functions/paramserv/ParamservSBPTest.java |  75 +++++++++
 .../functions/paramserv/ParamservSparkNNTest.java  |  10 ++
 .../ParamservSparkNNTestwithNbatches.java          |   7 +-
 .../functions/paramserv/paramserv-all-args.dml     |   2 +-
 .../{paramserv-all-args.dml => paramserv-sbp.dml}  |  37 ++--
 18 files changed, 391 insertions(+), 80 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 56275fd3b9..c83b6f3911 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -321,7 +321,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			Statement.PS_VAL_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
 			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
 			Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES,
-			Statement.PS_MODELAVG, Statement.PS_HE);
+			Statement.PS_MODELAVG, Statement.PS_HE, Statement.PS_NUM_BACKUP_WORKERS);
 		checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
 		// check existence and correctness of parameters
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index d22a540180..97a1a3dc68 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -77,18 +77,26 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_MODELAVG = "modelAvg";
 	public static final String PS_NBATCHES = "nbatches";
 	public static final String PS_HE = "he";
+	public static final String PS_NUM_BACKUP_WORKERS = "num_backup_workers";
+
 	public enum PSModeType {
 		FEDERATED, LOCAL, REMOTE_SPARK
 	}
 	public static final String PS_UPDATE_TYPE = "utype";
 	public enum PSUpdateType {
-		BSP, ASP, SSP;
+		BSP, // Bulk Synchronous Parallel
+		ASP, // Asynchronous Parallel
+		SSP, // Stale-Synchronous Parallel
+		SBP; // Synchronous w/ Backup-workers Parallel
 		public boolean isBSP() {
 			return this == BSP;
 		}
 		public boolean isASP() {
 			return this == ASP;
 		}
+		public boolean isSBP() { // SBP: synchronous rounds that do not wait for up to #backup-workers stragglers
+			return this == SBP;
+		}
 	}
 	public static final String PS_FREQUENCY = "freq";
 	public enum PSFrequency {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
index 4e873abdb6..f5846a400d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.java
@@ -50,18 +50,19 @@ public class HEParamServer extends LocalParamServer {
 
 	public static HEParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-		MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, int numBackupWorkers)
 	{
 		NativeHEHelper.initialize();
 		return new HEParamServer(model, aggFunc, updateType, freq, ec,
-				workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
+				workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, numBackupWorkers);
 	}
 
 	private HEParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc,
-		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches)
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, int numBackupWorkers)
 	{
-		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, true);
+		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels,
+			nbatches, true, numBackupWorkers);
 
 		_seal_server = new SEALServer();
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 9fd49ca0d1..b5b012048d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -33,17 +33,18 @@ public class LocalParamServer extends ParamServer {
 
 	public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, int numBackupWorkers)
 	{
 		return new LocalParamServer(model, aggFunc, updateType, freq, ec,
-			workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
+			workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, numBackupWorkers);
 	}
 
 	protected LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, int numBackupWorkers)
 	{
-		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
+		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels,
+			nbatches, modelAvg, numBackupWorkers);
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 0e09fabf30..3957965988 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -77,13 +77,16 @@ public abstract class ParamServer
 	private int _numBatchesPerEpoch;
 
 	private int _numWorkers;
+	private int _numBackupWorkers;
+	private boolean[] _discardWorkerRes;
 	private boolean _modelAvg;
 	private ListObject _accModels = null;
 
 	protected ParamServer() {}
 
 	protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
-		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,	MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
+		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, int numBackupWorkers)
 	{
 		// init worker queues and global model
 		_modelMap = new HashMap<>(workerNum);
@@ -105,6 +108,8 @@ public abstract class ParamServer
 		}
 		_numBatchesPerEpoch = numBatchesPerEpoch;
 		_numWorkers = workerNum;
+		_numBackupWorkers = numBackupWorkers;
+		_discardWorkerRes = new boolean[workerNum];
 		_modelAvg = modelAvg;
 
 		// broadcast initial model
@@ -207,39 +212,8 @@ public abstract class ParamServer
 					else
 						updateGlobalModel(gradients);
 
-					if (allFinished()) {
-						// Update the global model with accrued gradients
-						if( ACCRUE_BSP_GRADIENTS ) {
-							updateGlobalModel(_accGradients);
-							_accGradients = null;
-						}
-
-						// This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch
-						// In the BSP batch case that occurs after the sync counter reaches the number of batches and in the
-						// BSP epoch case every time
-						if (_numBatchesPerEpoch != -1 &&
-							(_freq == Statement.PSFrequency.EPOCH ||
-							(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))||
-							(_freq == Statement.PSFrequency.NBATCHES)) {
-
-						if(LOG.isInfoEnabled())
-								LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
-
-							time_epoch();
-
-							if(_validationPossible)
-								validate();
-
-							_epochCounter++;
-							_syncCounter = 0;
-						}
-
-						// Broadcast the updated model
-						resetFinishedStates();
-						broadcastModel(true);
-						if (LOG.isDebugEnabled())
-							LOG.debug("Global parameter is broadcasted successfully.");
-					}
+					if (allFinished())
+						performGlobalGradientUpdate();
 					break;
 				}
 				case ASP: {
@@ -265,6 +239,28 @@ public abstract class ParamServer
 					broadcastModel(workerID);
 					break;
 				}
+				case SBP: {
+					if(_discardWorkerRes[workerID]) {
+						// late result of a straggler whose round was already aggregated: drop it and resync the worker
+						LOG.info("[+] PARAMSERV: discarding result of backup-worker/straggler " + workerID);
+						broadcastModel(workerID);
+						_discardWorkerRes[workerID] = false;
+						break;
+					}
+					setFinishedState(workerID);
+					// Accumulate the intermediate gradients (or apply them right away if accruing is disabled)
+					if(ACCRUE_BSP_GRADIENTS)
+						_accGradients = ParamservUtils.accrueGradients(_accGradients, gradients, true);
+					else
+						updateGlobalModel(gradients);
+					// the round completes once at most #backup-workers workers are still outstanding
+					if(enoughFinished()) {
+						// flag the outstanding workers so that their results are thrown away on arrival
+						tagStragglers();
+						performGlobalGradientUpdate();
+					}
+					break;
+				}
 				default:
 					throw new DMLRuntimeException("Unsupported update: " + _updateType.name());
 			}
@@ -274,6 +270,50 @@ public abstract class ParamServer
 		}
 	}
 
+	private void performGlobalGradientUpdate() { // completes one sync round: aggregate, (maybe) validate, broadcast
+		// Update the global model with accrued gradients (without ACCRUE_BSP_GRADIENTS they were applied on arrival)
+		if(ACCRUE_BSP_GRADIENTS) {
+			updateGlobalModel(_accGradients);
+			_accGradients = null;
+		}
+		// NOTE: finishedEpoch() increments _syncCounter for BATCH frequency, so evaluate it exactly once per round
+		if(finishedEpoch()) {
+			if(LOG.isInfoEnabled())
+				LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
+
+			time_epoch();
+
+			if(_validationPossible)
+				validate();
+
+			_epochCounter++;
+			_syncCounter = 0;
+		}
+		// Broadcast the updated model only to workers that contributed to this round; tagged stragglers are
+		// resynced individually when their late result arrives. _finishedStates is the mask, so reset afterwards.
+		broadcastModel(_finishedStates);
+		resetFinishedStates();
+		if(LOG.isDebugEnabled())
+			LOG.debug("Global parameter is broadcasted successfully.");
+	}
+
+	private void tagStragglers() {
+		// every worker that has not reported in the current round gets its next result discarded
+		IntStream.range(0, _finishedStates.length)
+			.filter(id -> !_finishedStates[id])
+			.forEach(id -> _discardWorkerRes[id] = true);
+	}
+
+	private boolean finishedEpoch() {
+		// NBATCHES: every synchronization round is treated as an epoch boundary
+		if(_freq == Statement.PSFrequency.NBATCHES)
+			return true;
+		if(_numBatchesPerEpoch == -1)
+			return false;
+		// EPOCH: always; BATCH: every _numBatchesPerEpoch-th round (side effect: advances _syncCounter for BATCH)
+		return _freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0);
+	}
+
 	private void updateGlobalModel(ListObject gradients) {
 		Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
 		_model = updateLocalModel(_ec, gradients, _model);
@@ -314,10 +354,10 @@ public abstract class ParamServer
 			}
 			Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
 
-			//first weight the models based on number of workers
-			ListObject weightParams = weightModels(model, _numWorkers);
 			switch(_updateType) {
 				case BSP: {
+					//first weight the models based on number of workers
+					ListObject weightParams = weightModels(model, _numWorkers);
 					setFinishedState(workerID);
 					// second Accumulate the given weightModels into the accrued models
 					_accModels = ParamservUtils.accrueGradients(_accModels, weightParams, true);
@@ -328,6 +368,26 @@ public abstract class ParamServer
 					}
 					break;
 				}
+				case SBP: {
+					if(_discardWorkerRes[workerID]) {
+						// late result of a straggler whose round was already aggregated: drop it and resync the worker
+						LOG.info("[+] PARAMSERV: discarding result of backup-worker/straggler " + workerID);
+						broadcastModel(workerID);
+						_discardWorkerRes[workerID] = false;
+						break;
+					}
+					// first weight the models by the number of workers aggregated per round (#workers - #backup-workers)
+					ListObject weightParams = weightModels(model, _numWorkers - _numBackupWorkers);
+					setFinishedState(workerID);
+					// second Accumulate the given weightModels into the accrued models
+					_accModels = ParamservUtils.accrueGradients(_accModels, weightParams, true);
+					if(enoughFinished()) {
+						tagStragglers();
+						updateAndBroadcastModel(_accModels, tAgg, _finishedStates);
+						resetFinishedStates();
+					}
+					break;
+				}
 				case ASP:
 					throw new NotImplementedException();
 
@@ -341,15 +401,28 @@ public abstract class ParamServer
 	}
 
 	protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg) {
+		updateAndBroadcastModel(new_model, tAgg, null);
+	}
+
+	/**
+	 * Update the model and broadcast to (possibly a subset) the workers.
+	 * 
+	 * @param new_model           the new model
+	 * @param tAgg                time for statistics
+	 * @param workerBroadcastMask if null, broadcast to all workers, otherwise only to the ids with
+	 *                            <code>workerBroadcastMask[workerId] == true</code>
+	 */
+	protected void updateAndBroadcastModel(ListObject new_model, Timing tAgg, boolean[] workerBroadcastMask) {
 		_model = setParams(_ec, new_model, _model);
-		if (DMLScript.STATISTICS && tAgg != null)
+		if(DMLScript.STATISTICS && tAgg != null)
 			ParamServStatistics.accAggregationTime((long) tAgg.stop());
-		_accModels = null; //reset for next accumulation
+		_accModels = null; // reset for next accumulation
 
 		// This if has grown to be quite complex its function is rather simple. Validate at the end of each epoch
 		// In the BSP batch case that occurs after the sync counter reaches the number of batches and in the
 		// BSP epoch case every time
-		if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+		if(_numBatchesPerEpoch != -1 && (_freq == Statement.PSFrequency.EPOCH ||
+			(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
 
 			if(LOG.isInfoEnabled())
 				LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
@@ -361,7 +434,10 @@ public abstract class ParamServer
 			_syncCounter = 0;
 		}
 		// Broadcast the updated model
-		broadcastModel(true);
+		if(workerBroadcastMask == null)
+			broadcastModel(true);
+		else
+			broadcastModel(workerBroadcastMask);
 		if(LOG.isDebugEnabled())
 			LOG.debug("Global parameter is broadcasted successfully ");
 	}
@@ -395,10 +471,21 @@ public abstract class ParamServer
 		return accModels;
 	}
 
-		private boolean allFinished() {
+	private boolean allFinished() {
 		return !ArrayUtils.contains(_finishedStates, false);
 	}
 
+	private boolean enoughFinished() {
+		// with a single worker there is nobody else to wait for
+		if(_finishedStates.length == 1)
+			return _finishedStates[0];
+		// the round is complete once at most _numBackupWorkers workers are still outstanding
+		long finishedCount = IntStream.range(0, _finishedStates.length)
+			.filter(i -> _finishedStates[i])
+			.count();
+		return _numWorkers - finishedCount <= _numBackupWorkers;
+	}
+
 	private void resetFinishedStates() {
 		Arrays.fill(_finishedStates, false);
 	}
@@ -421,6 +508,24 @@ public abstract class ParamServer
 		});
 	}
 
+	/**
+	 * Broadcast the current global model to a subset of the workers, in parallel.
+	 * 
+	 * @param mask per-worker flags; worker <code>i</code> receives the model iff <code>mask[i]</code> is true
+	 */
+	private void broadcastModel(boolean[] mask) {
+		IntStream.range(0, _modelMap.size()).parallel()
+			.filter(id -> mask[id])
+			.forEach(id -> {
+				try {
+					broadcastModel(id);
+				}
+				catch(InterruptedException e) {
+					throw new DMLRuntimeException("Paramserv func: some error occurred when broadcasting model", e);
+				}
+			});
+	}
+
 	private void broadcastModel(int workerID) throws InterruptedException {
 		Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
 		//broadcast copy of model to specific worker, cleaned up by worker
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index d16aa9ec4e..1fa83b2a8d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -78,6 +78,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	public static final int DEFAULT_NBATCHES = 1;
 	private static final Boolean DEFAULT_MODELAVG = false;
 	private static final Boolean DEFAULT_HE = false;
+	public static final int DEFAULT_NUM_BACKUP_WORKERS = 1;
 
 	public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
 		super(op, paramsMap, out, opcode, istr);
@@ -124,6 +125,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		boolean weighting = getWeighting();
 		int seed = getSeed();
 		int nbatches = getNbatches();
+		int numBackupWorkers = getNumBackupWorkers();
 
 		if( LOG.isInfoEnabled() ) {
 			LOG.info("[+] Update Type: " + updateType);
@@ -180,8 +182,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 			throw new DMLRuntimeException("can't use homomorphic encryption with weighting");
 		}
 
-		LocalParamServer ps = (LocalParamServer)createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
-			getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg, use_homomorphic_encryption);
+		LocalParamServer ps = (LocalParamServer) createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum,
+			model, aggServiceEC, getValFunction(), getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics),
+			val_features, val_labels, nbatches, modelAvg, use_homomorphic_encryption, numBackupWorkers);
 		// Create the local workers
 		int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
 		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
@@ -239,6 +242,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		String updFunc = getParam(PS_UPDATE_FUN);
 		String aggFunc = getParam(PS_AGGREGATION_FUN);
 		int nbatches = getNbatches();
+		int numBackupWorkers = getNumBackupWorkers();
 		boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
 
 		// Get the compiled execution context
@@ -251,7 +255,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 		// Create the parameter server
 		ListObject model = sec.getListObject(getParam(PS_MODEL));
-		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), getFrequency(), workerNum, model, aggServiceEC, nbatches, modelAvg);
+		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), getFrequency(), workerNum, model, aggServiceEC,
+			nbatches, modelAvg, numBackupWorkers);
 
 		// Get driver host
 		String host = sec.getSparkContext().getConf().get("spark.driver.host");
@@ -340,14 +345,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		double rows_per_worker = Math.ceil((float) ec.getMatrixObject(getParam(PS_FEATURES)).getNumRows() / workerNum);
 		int num_batches_per_epoch = (int) Math.ceil(rows_per_worker / getBatchSize());
 		int nbatches = getNbatches();
+		int numBackupWorkers = getNumBackupWorkers();
 
 		// Create the parameter server
 		ListObject model = ec.getListObject(getParam(PS_MODEL));
 		MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) ? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
 		MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
 		boolean modelAvg = getModelAvg();
-		ParamServer ps = createPS(mode, aggFunc, updateType, freq, workerNum, model, aggServiceEC,
-			getValFunction(), num_batches_per_epoch, val_features, val_labels, nbatches, modelAvg);
+		ParamServer ps = createPS(mode, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
+			num_batches_per_epoch, val_features, val_labels, nbatches, modelAvg, numBackupWorkers);
 
 		// Create the local workers
 		List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -488,31 +494,44 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	 * @return parameter server
 	 */
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
-										PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg)
+										PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg, int numBackupWorkers)
 	{
-		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg);
+		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches,
+			modelAvg, numBackupWorkers);
 	}
 
 
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
 										PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc,
-										int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)	{
-		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, false);
+										int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, int numBackupWorkers) {
+		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, valFunc, numBatchesPerEpoch, valFeatures,
+			valLabels, nbatches, modelAvg, false, numBackupWorkers);
 	}
 
 	// When this creation is used the parameter server is able to validate after each epoch
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
 		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc,
-		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg, boolean use_homomorphic_encryption)
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg,
+		boolean use_homomorphic_encryption, int numBackupWorkers)
 	{
+		if(updateType.isSBP()) {
+			if(numBackupWorkers < 0 || numBackupWorkers >= workerNum)
+				throw new DMLRuntimeException(
+					"Invalid number of backup workers (with #workers=" + workerNum + "): #backup-workers="
+						+ numBackupWorkers);
+			if (numBackupWorkers == 0)
+				LOG.warn("SBP mode with 0 backup workers is the same as choosing BSP mode.");
+		}
 		switch (mode) {
 			case FEDERATED:
 			case LOCAL:
 			case REMOTE_SPARK:
 				if (use_homomorphic_encryption) {
-					return HEParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches);
+					return HEParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
+						numBatchesPerEpoch, valFeatures, valLabels, nbatches, numBackupWorkers);
 				} else {
-					return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
+					return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
+						numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg, numBackupWorkers);
 				}
 			default:
 				throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
@@ -551,6 +570,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 				LOG.warn(String.format("There is only %d batches of data but has %d workers. "
 					+ "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size()));
 			}
+			if (getUpdateType().isSBP() && pfs.size() <= getNumBackupWorkers()) {
+				throw new DMLRuntimeException(
+					"Effective number of workers is smaller or equal to the number of backup workers."
+						+ " Change partitioning scheme to OVERLAP_RESHUFFLE, decrease number of backup workers or increase number of rows in dataset.");
+			}
 			workers = workers.subList(0, pfs.size());
 		}
 		for (int i = 0; i < workers.size(); i++) {
@@ -635,6 +659,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		return Integer.parseInt(getParam(PS_NBATCHES));
 	}
 
+	private int getNumBackupWorkers() {
+		if(!getParameterMap().containsKey(PS_NUM_BACKUP_WORKERS))
+			return DEFAULT_NUM_BACKUP_WORKERS;
+		// warn only if the user explicitly passed the parameter; it is ignored outside of SBP mode
+		if(!getUpdateType().isSBP())
+			LOG.warn("Specifying number of backup-workers without SBP mode has no effect");
+		return Integer.parseInt(getParam(PS_NUM_BACKUP_WORKERS));
+	}
+
 	private boolean checkIsPrivate(MatrixObject obj) {
 		PrivacyConstraint pc = obj.getPrivacyConstraint();
 		return pc != null && pc.hasPrivateElements();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index 702f632521..b7eb0c2283 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -68,6 +68,9 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
 			{"TwoNN",	2, 4, 1, 4, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"true",	"IMBALANCED",	200},
 			{"CNN", 	2, 4, 1, 4, 0.01, 		"BSP", "EPOCH", "SHUFFLE", 				"BASELINE",		"true",	"IMBALANCED", 	200},
 			{"TwoNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE",			"true",	"BALANCED",		200},
+			{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"true",	"IMBALANCED",	200},
+			{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"true",	"BALANCED",	200},
+			{"CNN",		2, 4, 1, 4, 0.01, 		"SBP", "EPOCH", "SHUFFLE",			 	"BASELINE",		"true",	"BALANCED",	200},
 
 			/*
 				// runtime balancing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
index ca50338e33..2a6bbd950d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -71,6 +71,9 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 				{"TwoNN",	2, 4, 1, 1, 0.01, 		"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"false",	"IMBALANCED",	200},
 				{"CNN", 	2, 4, 1, 1, 0.01, 		"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "BASELINE",		"false",	"IMBALANCED", 	200},
 				//{"TwoNN", 	5, 1000, 100, 1, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE",			"true",	"BALANCED",		200},
+				{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"false",	"IMBALANCED",	200},
+				{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"false",	"BALANCED",		200},
+				{"CNN",		2, 4, 1, 4, 0.01, 		"SBP", "EPOCH", "SHUFFLE",			 	"BASELINE",		"false",	"BALANCED",		200},
 
 				/*
                     // runtime balancing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index fd40275dde..81463b4c54 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -68,6 +68,9 @@ public class FederatedParamservTest extends AutomatedTestBase {
 			{"CNN",		2, 4, 1, 4, 0.01, 		"ASP", "BATCH", "REPLICATE_TO_MAX", 	"CYCLE_MIN", 	"true",	"IMBALANCED",	200},
 			{"TwoNN", 	2, 4, 1, 4, 0.01, 		"ASP", "EPOCH", "BALANCE_TO_AVG", 		"CYCLE_MAX", 	"true",	"IMBALANCED",	200},
 			{"TwoNN", 	5, 1000, 100, 2, 0.01, 	"BSP", "BATCH", "KEEP_DATA_ON_WORKER", 	"NONE", 		"true",	"BALANCED",		200},
+			{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"true",	"IMBALANCED",	200},
+			{"TwoNN",	2, 4, 1, 4, 0.01, 		"SBP", "BATCH", "KEEP_DATA_ON_WORKER", 	"BASELINE",		"true",	"BALANCED",		200},
+			{"CNN",		2, 4, 1, 4, 0.01, 		"SBP", "EPOCH", "SHUFFLE",			 	"BASELINE",		"true",	"BALANCED",		200},
 
 			/*
 				// runtime balancing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
index 9b9f9bba85..9b307a4e53 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -66,6 +66,8 @@ public class NbatchesFederatedParamservTest extends AutomatedTestBase {
 			{"CNN",   2, 2000, 100, 4, 0.01, "BSP", "NBATCHES", "SHUFFLE",             "NONE",      "true", "BALANCED", 200, 8},
 			{"CNN",   2, 480,  32,  4, 0.01, "ASP", "NBATCHES", "REPLICATE_TO_MAX",    "CYCLE_MIN", "true", "BALANCED", 200, 16},
 			{"TwoNN", 5, 2000, 100, 2, 0.01, "BSP", "NBATCHES", "KEEP_DATA_ON_WORKER", "NONE",      "true", "BALANCED", 200, 2},
+			{"CNN",   2, 2000, 100, 4, 0.01, "SBP", "NBATCHES", "SHUFFLE",             "NONE",      "true", "BALANCED", 200, 8}, // SBP variant of the BSP CNN row
+			{"TwoNN", 5, 2000, 100, 2, 0.01, "SBP", "NBATCHES", "KEEP_DATA_ON_WORKER", "NONE",      "true", "BALANCED", 200, 2}, // SBP variant of the BSP TwoNN row
 		});
 	}
 
 
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
index 396f24fd45..b103adf7ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
@@ -62,6 +62,31 @@ public class ParamservLocalNNAveragingTest extends AutomatedTestBase {
 		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE, true);
 	}
 
+	@Test
+	public void testParamservSBPBatchDisjointContiguous() { // SBP tests use 3 workers (BSP above uses 2); presumably so non-backup workers remain — confirm default num_backup_workers
+		runDMLTest(10, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+	}
+
+	@Test
+	public void testParamservSBPEpoch() { // SBP with EPOCH update frequency
+		runDMLTest(10, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+	}
+
+	@Test
+	public void testParamservSBPBatchDisjointRoundRobin() { // SBP, BATCH, DISJOINT_ROUND_ROBIN partitioning
+		runDMLTest(10, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN, true);
+	}
+
+	@Test
+	public void testParamservSBPBatchDisjointRandom() { // SBP, BATCH, DISJOINT_RANDOM partitioning
+		runDMLTest(10, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM, true);
+	}
+
+	@Test
+	public void testParamservSBPBatchOverlapReshuffle() { // SBP, BATCH, OVERLAP_RESHUFFLE partitioning
+		runDMLTest(10, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE, true);
+	}
+
 	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, boolean modelAvg) {
 		TestConfiguration config = getTestConfiguration(ParamservLocalNNAveragingTest.TEST_NAME);
 		loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
index 351e01904f..d371e83ff0 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
@@ -52,6 +52,11 @@ public class ParamservLocalNNTest extends AutomatedTestBase {
 		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
+	@Test
+	public void testParamservSBPEpoch() { // SBP tests use 3 workers; the BSP/ASP tests here use 2
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
 	@Test
 	public void testParamservASPEpoch() {
 		runDMLTest(3, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
@@ -67,6 +72,21 @@ public class ParamservLocalNNTest extends AutomatedTestBase {
 		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
 	}
 
+	@Test
+	public void testParamservSBPBatch() { // SBP, BATCH, DISJOINT_CONTIGUOUS partitioning
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
+	@Test
+	public void testParamservSBPBatchDisjointRoundRobin() { // SBP, BATCH, DISJOINT_ROUND_ROBIN partitioning
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
+	}
+
+	@Test
+	public void testParamservSBPBatchDisjointRandom() { // SBP, BATCH, DISJOINT_RANDOM partitioning
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
+	}
+
 	@Test
 	public void testParamservBSPBatchOverlapReshuffle() {
 		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
index b0d21c64c5..c8d20f3a7b 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
@@ -52,6 +52,16 @@ public class ParamservLocalNNTestwithNbatches extends AutomatedTestBase {
 		runDMLTest(3, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, false);
 	}
 
+	@Test
+	public void testParamservSBPNBatchesDisjointContiguous() { // SBP with NBATCHES frequency (nbatches=8), 3 workers
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, false );
+	}
+
+	@Test
+	public void testParamservSBPNBatchesDisjointContiguousModelAvg() { // same as above with model averaging enabled
+		runDMLTest(3, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, true );
+	}
+
 	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, int nbatches, boolean modelAvg) {
 		TestConfiguration config = getTestConfiguration(ParamservLocalNNTestwithNbatches.TEST_NAME);
 		loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSBPTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSBPTest.java
new file mode 100644
index 0000000000..27c87b462c
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSBPTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.paramserv;
+
+import org.apache.sysds.parser.Statement;
+import org.junit.Test;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class ParamservSBPTest extends AutomatedTestBase { // validates num_backup_workers handling of paramserv with utype=SBP
+
+	private static final String TEST_NAME = "paramserv-sbp"; // runs functions/paramserv/paramserv-sbp.dml
+
+	private static final String TEST_DIR = "functions/paramserv/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSBPTest.class.getSimpleName() + "/";
+
+	private final String HOME = SCRIPT_DIR + TEST_DIR;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
+	}
+
+	@Test
+	public void testParamservNegativeNumBackupWorkers() { // a negative backup-worker count must be rejected
+		runDMLTest(TEST_NAME, "Invalid number of backup workers (with #workers=3): #backup-workers=-1", -1,
+			Statement.PSScheme.OVERLAP_RESHUFFLE);
+	}
+
+	@Test
+	public void testParamservAllBackupWorkers() { // backup workers == workers (3) must be rejected
+		runDMLTest(TEST_NAME, "Invalid number of backup workers (with #workers=3): #backup-workers=3", 3,
+			Statement.PSScheme.OVERLAP_RESHUFFLE);
+	}
+
+	@Test
+	public void testParamservTooFewEffectiveWorkers() { // DISJOINT_CONTIGUOUS leaves effective workers <= backup workers (1)
+		runDMLTest(TEST_NAME,
+			"Effective number of workers is smaller or equal to the number of backup workers. Change partitioning scheme to OVERLAP_RESHUFFLE, decrease number of backup workers or increase number of rows in dataset.",
+			1, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
+	@Test
+	public void testParamservNormalRun() { // valid config: 1 backup worker out of 3 with OVERLAP_RESHUFFLE must succeed
+		runDMLTest(TEST_NAME, null, 1, Statement.PSScheme.OVERLAP_RESHUFFLE);
+	}
+
+	private void runDMLTest(String testname, String errmsg, int numBackupWorkers, Statement.PSScheme scheme) { // errmsg == null: run must succeed; otherwise a DMLRuntimeException (presumably matching errmsg) is expected
+		TestConfiguration config = getTestConfiguration(testname);
+		loadTestConfiguration(config);
+		programArgs = new String[] {"-explain", "-nvargs", "scheme=" + scheme, "workers=3", "backup_workers=" + numBackupWorkers}; // workers fixed at 3; only scheme and backup count vary
+		fullDMLScriptName = HOME + testname + ".dml";
+		boolean exceptionExpected = errmsg != null;
+		runTest(true, exceptionExpected, DMLRuntimeException.class, errmsg, -1);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
index e59e5705ac..c7f0e39dff 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
@@ -56,6 +56,11 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 		runDMLTest(2, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
+	@Test
+	public void testParamservSBPBatchDisjointContiguous() { // SBP uses 3 workers; the BSP/ASP tests here use 2
+		runDMLTest(2, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
 	@Test
 	public void testParamservBSPEpochDisjointContiguous() {
 		runDMLTest(5, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
@@ -66,6 +71,11 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 		runDMLTest(5, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
+	@Test
+	public void testParamservSBPEpochDisjointContiguous() { // runs 2 epochs, fewer than the 5 of the BSP/ASP epoch tests
+		runDMLTest(2, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
 	@Test
 	public void testParamservWorkerFailed() {
 		runDMLTest(TEST_NAME2, true, DMLRuntimeException.class, "Invalid indexing by name in unnamed list: worker_err.");
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java
index 11e40a15db..b5faae81db 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java
@@ -51,6 +51,11 @@ public class ParamservSparkNNTestwithNbatches extends AutomatedTestBase {
 		runDMLTest(2, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS, 16, false);
 	}
 
+	@Test
+	public void testParamservSBPNbatchesDisjointContiguous() {
+		runDMLTest(2, 3, Statement.PSUpdateType.SBP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS, 16, false); // NOTE(review): named "Nbatches" but uses PSFrequency.BATCH (mirrors the ASP test above) — confirm NBATCHES was not intended
+	}
+
 	private void internalRunDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException,
 		String errMessage) {
 		ExecMode oldRtplatform = AutomatedTestBase.rtplatform;
@@ -71,7 +76,7 @@ public class ParamservSparkNNTestwithNbatches extends AutomatedTestBase {
 	}
 
 	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, int nbatches, boolean modelAvg) {
-		programArgs = new String[] { "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme + "nbatches=" + nbatches, "modelAvg=" + modelAvg};
+		programArgs = new String[] { "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme, "nbatches=" + nbatches, "modelAvg=" + modelAvg}; // scheme and nbatches are now separate -nvargs entries
 		internalRunDMLTest(TEST_NAME1, false, null, null);
 	}
 }
diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
index ec6e087484..3edbffbac1 100644
--- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
@@ -38,6 +38,6 @@ e2 = "element2"
 hps = list(e2=e2)
 
 # Use paramserv function
-paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
+paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE", num_backup_workers=1)  # num_backup_workers passed alongside utype="BSP"; presumably only consumed by SBP — confirm
 
 print(length(paramsList2))
\ No newline at end of file
diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-sbp.dml
similarity index 50%
copy from src/test/scripts/functions/paramserv/paramserv-all-args.dml
copy to src/test/scripts/functions/paramserv/paramserv-sbp.dml
index ec6e087484..6d48f37674 100644
--- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-sbp.dml
@@ -19,25 +19,32 @@
 #
 #-------------------------------------------------------------
 
-e1 = "element1"
-paramsList = list(e1=e1)
-X = matrix(1, rows=2, cols=3)
-Y = matrix(2, rows=2, cols=3)
-X_val = matrix(3, rows=2, cols=3)
-Y_val = matrix(4, rows=2, cols=3)
-
-gradients = function (matrix[double] input) return (matrix[double] output) {
-  output = input
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+  gradients = model;  # stub: returns the model itself as the "gradients"
 }
 
-aggregation = function (matrix[double] input) return (matrix[double] output) {
-  output = input
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+   return (list[unknown] modelResult) {
+  modelResult = model;  # stub: keeps the model unchanged
 }
 
-e2 = "element2"
-hps = list(e2=e2)
+model = list(matrix(0, 2, 3))
+X = matrix(1, 2, 3)  # only 2 rows; presumably what makes DISJOINT_CONTIGUOUS fail the effective-worker check
+Y = matrix(2, 2, 3)
+X_val = matrix(3, 2, 3)
+Y_val = matrix(4, 2, 3)
+hps = list()  # the stub functions read no hyperparameters
 
 # Use paramserv function
-paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
+supd = ".defaultNS::gradients";  # fully-qualified name of the stub update function above
+sagg = ".defaultNS::aggregation";  # fully-qualified name of the stub aggregation function above
+model = paramserv(model=model, k=$workers, features=X, labels=Y, val_features=X_val, val_labels=Y_val,
+  upd=supd, agg=sagg, mode="LOCAL", utype="SBP", freq="EPOCH", scheme=$scheme, epochs=1, hyperparams=hps, num_backup_workers=$backup_workers)  # $workers, $scheme, $backup_workers are supplied as named script args (-nvargs)
 
-print(length(paramsList2))
\ No newline at end of file
+print(toString(model))