Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/10 21:08:14 UTC

[systemds] branch master updated: [SYSTEMDS-3018] Extended parameter server w/ nbatch update frequency

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new dadc901  [SYSTEMDS-3018] Extended parameter server w/ nbatch update frequency
dadc901 is described below

commit dadc9012f857863596cd8f124d64f2b66eb73dc6
Author: Atefeh Asayesh <at...@gmail.com>
AuthorDate: Fri Sep 10 23:06:52 2021 +0200

    [SYSTEMDS-3018] Extended parameter server w/ nbatch update frequency
    
    Extended the local, Spark, and federated parameter servers.
    Closes #1382.
---
 .../ParameterizedBuiltinFunctionExpression.java    |   2 +-
 .../java/org/apache/sysds/parser/Statement.java    |   3 +-
 .../paramserv/FederatedPSControlThread.java        |  90 ++--
 .../controlprogram/paramserv/LocalPSWorker.java    |  45 +-
 .../controlprogram/paramserv/LocalParamServer.java |   8 +-
 .../runtime/controlprogram/paramserv/PSWorker.java |   4 +-
 .../controlprogram/paramserv/ParamServer.java      |  11 +-
 .../controlprogram/paramserv/SparkPSWorker.java    |   3 +-
 .../cp/ParamservBuiltinCPInstruction.java          |  32 +-
 .../paramserv/NbatchesFederatedParamservTest.java  | 226 ++++++++++
 .../functions/paramserv/ParamservLocalNNTest.java  |  14 +-
 ....java => ParamservLocalNNTestwithNbatches.java} |  48 +--
 .../ParamservSparkNNTestwithNbatches.java          |  77 ++++
 .../federated/paramserv/CNNwithNbatches.dml        | 471 +++++++++++++++++++++
 .../FederatedParamservTestwithNbatches.dml         |  40 ++
 .../paramserv/NbatchesFederatedParamservTest.dml   |  40 ++
 .../federated/paramserv/TwoNNwithNbatches.dml      | 305 +++++++++++++
 .../paramserv/mnist_lenet_paramserv_nbatches.dml   | 372 ++++++++++++++++
 .../paramserv/paramserv-nbatches-test.dml          |  49 +++
 19 files changed, 1739 insertions(+), 101 deletions(-)
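
For readers skimming this commit: the new "nbatches" frequency lets each worker
accumulate gradients over N batches before pushing them to the parameter server,
instead of pushing after every batch or every epoch. A minimal DML sketch of the
extended API (a hedged example, not part of this commit; the network file and the
gradients/aggregation function names are illustrative placeholders):

  # pull the model, run 8 batches locally, then push the accumulated gradients
  model2 = paramserv(model=model, features=X, labels=y,
    upd="./network.dml::gradients", agg="./network.dml::aggregation",
    mode="LOCAL", utype="BSP", freq="NBATCHES", nbatches=8,
    epochs=3, batchsize=32, k=2, scheme="DISJOINT_CONTIGUOUS",
    hyperparams=params, modelAvg=FALSE)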

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 1f5fd16..609de41 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -315,7 +315,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
 			Statement.PS_VAL_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
 			Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
-			Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_MODELAVG);
+			Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_NBATCHES, Statement.PS_MODELAVG);
 		checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
 		// check existence and correctness of parameters
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index f9f5911..4b0237d 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -73,6 +73,7 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_GRADIENTS = "gradients";
 	public static final String PS_SEED = "seed";
 	public static final String PS_MODELAVG = "modelAvg";
+	public static final String PS_NBATCHES = "nbatches";
 	public enum PSModeType {
 		FEDERATED, LOCAL, REMOTE_SPARK
 	}
@@ -88,7 +89,7 @@ public abstract class Statement implements ParseInfo
 	}
 	public static final String PS_FREQUENCY = "freq";
 	public enum PSFrequency {
-		BATCH, EPOCH
+		BATCH, EPOCH, NBATCHES
 	}
 	public static final String PS_FED_WEIGHTING = "weighting";
 	public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 07ac212..85bb745 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -78,6 +78,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	// runtime balancing
 	private final PSRuntimeBalancing _runtimeBalancing;
 	private int _numBatchesPerEpoch;
+	private int _numBatchesPerNbatch;
 	private int _possibleBatchesPerLocalEpoch;
 	private final boolean _weighting;
 	private double _weightingFactor = 1;
@@ -85,13 +86,14 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 
 	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
 		PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize,
-		int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, boolean modelAvg)
+		int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg)
 	{
-		super(workerID, updFunc, freq, epochs, batchSize, ec, ps, modelAvg);
+		super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg);
 
 		_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
 		_runtimeBalancing = runtimeBalancing;
 		_weighting = weighting;
+		_numBatchesPerNbatch = nbatches;
 		// generate the ID for the model
 		_modelVarID = FederationUtils.getNextFedDataID();
 		_modelAvg = modelAvg;
@@ -114,7 +116,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 
 		// different runtime balancing calculations
 		long dataSize = _features.getNumRows();
-
 		// calculate scaled batch size if balancing via batch size.
 		// In some cases there will be some cycling
 		if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH)
@@ -147,7 +148,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		gradientProgramBlock.setInstructions(new ArrayList<>(Collections.singletonList(_inst)));
 		pbs.add(gradientProgramBlock);
 
-		if(_freq == PSFrequency.EPOCH) {
+		if(_freq == PSFrequency.EPOCH || _freq == PSFrequency.NBATCHES) {
 			BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
 			aggProgramBlock.setInstructions(new ArrayList<>(Collections.singletonList(_ps.getAggInst())));
 			pbs.add(aggProgramBlock);
@@ -164,7 +165,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				new SetupFederatedWorker(_batchSize, dataSize, _possibleBatchesPerLocalEpoch,
 					programSerialized, _inst.getNamespace(), _inst.getFunctionName(),
 					_ps.getAggInst().getFunctionName(), _ec.getListObject("hyperparams"),
-					_modelVarID, _modelAvg)));
+					_modelVarID, _nbatches, _modelAvg)));
 
 		try {
 			FederatedResponse response = udfResponse.get();
@@ -177,6 +178,26 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	}
 
 	/**
+	 * Cleans up the execution context of the federated worker
+	 */
+	public void teardown() {
+		// execute the teardown UDF on the federated worker
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
+			new TeardownFederatedWorker()
+		));
+
+		try {
+			FederatedResponse response = udfResponse.get();
+			if(!response.isSuccessful())
+				throw new DMLRuntimeException("FederatedLocalPSThread: Teardown UDF failed");
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF: " + e.getMessage());
+		}
+	}
+	
+	/**
 	 * Setup UDF executed on the federated worker
 	 */
 	private static class SetupFederatedWorker extends FederatedUDF {
@@ -191,10 +212,11 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 		private final ListObject _hyperParams;
 		private final long _modelVarID;
 		private final boolean _modelAvg;
+		private final int _nbatches;
 
 		protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
 			String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
-			ListObject hyperParams, long modelVarID, boolean modelAvg)
+			ListObject hyperParams, long modelVarID, int nbatches, boolean modelAvg)
 		{
 			super(new long[]{});
 			_batchSize = batchSize;
@@ -207,6 +229,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			_hyperParams = hyperParams;
 			_modelVarID = modelVarID;
 			_modelAvg = modelAvg;
+			_nbatches = nbatches;
 		}
 
 		@Override
@@ -223,6 +246,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
 			ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
 			ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));
+			ec.setVariable(Statement.PS_NBATCHES, new IntObject(_nbatches));
 			ec.setVariable(Statement.PS_MODELAVG, new BooleanObject(_modelAvg));
 
 			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
@@ -235,26 +259,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	}
 
 	/**
-	 * cleans up the execution context of the federated worker
-	 */
-	public void teardown() {
-		// write program and meta data to worker
-		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
-			new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
-			new TeardownFederatedWorker()
-		));
-
-		try {
-			FederatedResponse response = udfResponse.get();
-			if(!response.isSuccessful())
-				throw new DMLRuntimeException("FederatedLocalPSThread: Teardown UDF failed");
-		}
-		catch(Exception e) {
-			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage());
-		}
-	}
-
-	/**
 	 * Teardown UDF executed on the federated worker
 	 */
 	private static class TeardownFederatedWorker extends FederatedUDF {
@@ -298,9 +302,9 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				case BATCH:
 					computeWithBatchUpdates();
 					break;
-				/*case NBATCH:
+				case NBATCHES:
 					computeWithNBatchUpdates();
-					break; */
+					break;
 				case EPOCH:
 					computeWithEpochUpdates();
 					break;
@@ -361,6 +365,26 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 	}
 
 	/**
+	 * Computes all epochs and updates after N batches
+	 */
+	protected void computeWithNBatchUpdates() {
+		int numSetsPerEpochNbatches = (int) Math.ceil((double)_numBatchesPerEpoch / _numBatchesPerNbatch);
+		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
+			int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+
+			for (int batchCounter = 0; batchCounter < numSetsPerEpochNbatches; batchCounter++) {
+				int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber, numSetsPerEpochNbatches);
+				currentLocalBatchNumber = currentLocalBatchNumber + _numBatchesPerNbatch;
+				ListObject model = pullModel();
+				ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerNbatch, localStartBatchNum, true);
+				weightAndPushGradients(gradients);
+				ParamservUtils.cleanupListObject(model);
+				ParamservUtils.cleanupListObject(gradients);
+			}
+		}
+	}
+
+	/**
 	 * Computes all epochs and updates after each epoch
 	 */
 	protected void computeWithEpochUpdates() {
@@ -368,7 +392,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
 
 			// Pull the global parameters from ps
-			// TODO double check if model averaging is handled correctly (internally?)
 			ListObject model = pullModel();
 			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
 			weightAndPushGradients(gradients);
@@ -463,7 +486,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 			String gradientsFunc = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
 			String aggFunc = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
 			boolean modelAvg = ((BooleanObject) ec.getVariable(Statement.PS_MODELAVG)).getBooleanValue();
-
 			// recreate gradient instruction and output
 			boolean opt = !ec.getProgram().containsFunctionProgramBlock(namespace, gradientsFunc, false);
 			FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunc, opt);
@@ -494,7 +516,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 					opt, boundInputs, func.getInputParamNames(), outputNames, "aggregation function");
 				aggregationOutput = outputs.get(0);
 			}
-
 			ListObject accGradients = null;
 			int currentLocalBatchNumber = _localStartBatchNum;
 			// prepare execution context
@@ -515,18 +536,17 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 				// calculate gradients for batch
 				gradientsInstruction.processInstruction(ec);
 				ListObject gradients = ec.getListObject(gradientsOutput.getName());
-				
+
 				// accrue the computed gradients - In the single batch case this is just a list copy
 				// is this equivalent for momentum based and AMS prob?
 				accGradients = modelAvg ? null :
 					ParamservUtils.accrueGradients(accGradients, gradients, false);
-				
+
 				// update the local model with gradients if needed
 				// FIXME ensure that with modelAvg we always update the model
 				// (current fails due to missing aggregation instruction)
-				if(_localUpdate && batchCounter < _numBatchesToCompute - 1) {
+				if(_localUpdate && (batchCounter < _numBatchesToCompute - 1 | modelAvg) ) {
 					// Invoke the aggregate function
-					assert aggregationInstruction != null;
 					aggregationInstruction.processInstruction(ec);
 					// Get the new model
 					model = ec.getListObject(aggregationOutput.getName());
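
For intuition on the grouping in computeWithNBatchUpdates above, a small DML
sketch of the arithmetic (the numbers are illustrative, mirroring the first
federated test case added below: 2000 rows over 2 workers, batch size 100,
nbatches=8):

  numBatchesPerEpoch = 10   # 1000 local rows / batch_size 100
  nbatches = 8
  numSetsPerEpochNbatches = ceil(numBatchesPerEpoch / nbatches)
  # -> 2 pull/push rounds per epoch; the second covers only the remaining 2 batches
  print("pull/push rounds per epoch: " + numSetsPerEpochNbatches)
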
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index fef617b..f1848ad 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -43,9 +43,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 	protected LocalPSWorker() {}
 
 	public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq,
-		int epochs, long batchSize, ExecutionContext ec, ParamServer ps, boolean modelAvg)
+		int epochs, long batchSize, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg)
 	{
-		super(workerID, updFunc, freq, epochs, batchSize, ec, ps, modelAvg);
+		super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg);
 	}
 
 	@Override
@@ -67,6 +67,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 				case EPOCH:
 					computeEpoch(dataSize, batchIter);
 					break;
+				case NBATCHES:
+					computeNBatches(dataSize, batchIter);
+					break;
 				default:
 					throw new DMLRuntimeException(String.format("%s does not support update frequency %s", getWorkerName(), _freq));
 			}
@@ -120,6 +123,43 @@
 		}
 	}
 
+	private void computeNBatches(long dataSize, int batchIter) {
+		ListObject model = null;
+		Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);
+		for(int i = 0; i < _epochs; i++) {
+			try {
+				for(int j = 0; j < batchIter; j++) {
+					boolean localUpdate = j < batchIter;
+					if( j % _nbatches == 0 )
+						model = pullModel();
+					ListObject gradients = computeGradients(model, dataSize, batchIter, i, j);
+					// Accumulate the intermediate gradients (async for overlap w/ model updates
+					// and gradient computation, sequential over gradient matrices to avoid deadlocks)
+					ListObject accGradientsPrev = accGradients.get();
+					accGradients = _tpool
+						.submit(() -> ParamservUtils.accrueGradients(accGradientsPrev, gradients, false, !localUpdate));
+					// Update the local model with gradients
+					if(localUpdate | _modelAvg)
+						model = updateModel(model, gradients, i, j, batchIter);
+					accNumBatches(1);
+					
+					// Push the gradients to ps
+					if((j % _nbatches == (_nbatches-1)) || (j == batchIter-1)) {
+						pushGradients(_modelAvg ? model : accGradients.get());
+						accGradients = ConcurrentUtils.constantFuture(null);
+					}
+				}
+			}
+			catch(ExecutionException | InterruptedException ex) {
+				throw new DMLRuntimeException(ex);
+			}
+			accNumEpochs(1);
+			if(LOG.isDebugEnabled()) {
+				LOG.debug(String.format("%s: finished epoch %d.", getWorkerName(), i + 1));
+			}
+		}
+	}
+
 	private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int batchIter) {
 		Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index ebf6698..50c76a0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -33,17 +33,17 @@ public class LocalParamServer extends ParamServer {
 
 	public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-		MatrixObject valFeatures, MatrixObject valLabels, boolean modelAvg)
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
 	{
 		return new LocalParamServer(model, aggFunc, updateType, freq, ec,
-			workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, modelAvg);
+			workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
 	}
 
 	private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
 		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
-		MatrixObject valFeatures, MatrixObject valLabels, boolean modelAvg)
+		MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
 	{
-		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, modelAvg);
+		super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index a1e55f3..b146e65 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -57,11 +57,12 @@ public abstract class PSWorker implements Serializable
 	protected MatrixObject _labels;
 	protected String _updFunc;
 	protected Statement.PSFrequency _freq;
+	protected int _nbatches;
 	protected boolean _modelAvg;
 
 	protected PSWorker() {}
 
-	protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, boolean modelAvg) {
+	protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg) {
 		_workerID = workerID;
 		_updFunc = updFunc;
 		_freq = freq;
@@ -69,6 +70,7 @@ public abstract class PSWorker implements Serializable
 		_batchSize = batchSize;
 		_ec = ec;
 		_ps = ps;
+		_nbatches = nbatches;
 		_modelAvg = modelAvg;
 		setupUpdateFunction(updFunc, ec);
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index dc5b85f..a27aafc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -86,7 +86,7 @@ public abstract class ParamServer
 	protected ParamServer() {}
 
 	protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
-		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,	MatrixObject valFeatures, MatrixObject valLabels, boolean modelAvg)
+		Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,	MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
 	{
 		// init worker queues and global model
 		_modelMap = new HashMap<>(workerNum);
@@ -222,9 +222,10 @@ public abstract class ParamServer
 						// BSP epoch case every time
 						if (_numBatchesPerEpoch != -1 &&
 							(_freq == Statement.PSFrequency.EPOCH ||
-							(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+							(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0)) ||
+							(_freq == Statement.PSFrequency.NBATCHES)) {

 							if(LOG.isInfoEnabled())
 								LOG.info("[+] PARAMSERV: completed EPOCH " + _epochCounter);
 
 							time_epoch();
@@ -250,8 +251,8 @@ public abstract class ParamServer
 					// the number of workers, creating "Pseudo Epochs"
 					if (_numBatchesPerEpoch != -1 &&
 						((_freq == Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
-						(_freq == Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) _numBatchesPerEpoch == 0))) {
-
+						(_freq == Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) _numBatchesPerEpoch == 0)) ||
+						(_freq == Statement.PSFrequency.NBATCHES)) {
 						if(LOG.isInfoEnabled())
 							LOG.info("[+] PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
index 1f3cd1a..8b2fd96 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -54,7 +54,7 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 	private final LongAccumulator _nBatches; //number of executed batches
 	private final LongAccumulator _nEpochs; //number of executed epochs
 
-	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs, boolean modelAvg) {
+	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs, int nbatches, boolean modelAvg) {
 		_updFunc = updFunc;
 		_aggFunc = aggFunc;
 		_freq = freq;
@@ -72,6 +72,7 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		_aRPC = aRPC;
 		_nBatches = aBatches;
 		_nEpochs = aEpochs;
+		_nbatches = nbatches;
 		_modelAvg = modelAvg;
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 53987a0..ab60813 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -39,6 +39,7 @@ import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS;
 import static org.apache.sysds.parser.Statement.PS_LABELS;
 import static org.apache.sysds.parser.Statement.PS_MODE;
 import static org.apache.sysds.parser.Statement.PS_MODEL;
 import static org.apache.sysds.parser.Statement.PS_MODELAVG;
+import static org.apache.sysds.parser.Statement.PS_NBATCHES;
 import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
@@ -98,6 +99,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = FederatedPSScheme.KEEP_DATA_ON_WORKER;
 	private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
 	private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
+	public static final int DEFAULT_NBATCHES = 1;
 	private static final Boolean DEFAULT_MODELAVG = false;
 
 	public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
@@ -144,6 +146,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
 		boolean weighting = getWeighting();
 		int seed = getSeed();
+		int nbatches = getNbatches();
 
 		if( LOG.isInfoEnabled() ) {
 			LOG.info("[+] Update Type: " + updateType);
@@ -185,12 +188,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
 		boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
 		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
-			getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, modelAvg);
+			getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics), val_features, val_labels, nbatches, modelAvg);
 		// Create the local workers
 		int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
 		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
 			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting,
-				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, modelAvg))
+				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg))
 			.collect(Collectors.toList());
 		if(workerNum != threads.size()) {
 			throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
@@ -226,6 +229,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		int workerNum = getWorkerNum(mode);
 		String updFunc = getParam(PS_UPDATE_FUN);
 		String aggFunc = getParam(PS_AGGREGATION_FUN);
+		int nbatches = getNbatches();
 		boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
 
 		// Get the compiled execution context
@@ -238,7 +242,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 		// Create the parameter server
 		ListObject model = sec.getListObject(getParam(PS_MODEL));
-		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), getFrequency(), workerNum, model, aggServiceEC, modelAvg);
+		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), getFrequency(), workerNum, model, aggServiceEC, nbatches, modelAvg);
 
 		// Get driver host
 		String host = sec.getSparkContext().getConf().get("spark.driver.host");
@@ -268,7 +272,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		// Create remote workers
 		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
 			getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
-			server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch, modelAvg);
+			server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch, nbatches, modelAvg);
 
 		if (DMLScript.STATISTICS)
 			Statistics.accPSSetupTime((long) tSetup.stop());
@@ -325,6 +329,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 		double rows_per_worker = Math.ceil((float) ec.getMatrixObject(getParam(PS_FEATURES)).getNumRows() / workerNum);
 		int num_batches_per_epoch = (int) Math.ceil(rows_per_worker / getBatchSize());
+		int nbatches = getNbatches();
 
 		// Create the parameter server
 		ListObject model = ec.getListObject(getParam(PS_MODEL));
@@ -332,12 +337,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
 		boolean modelAvg = getModelAvg();
 		ParamServer ps = createPS(mode, aggFunc, updateType, freq, workerNum, model, aggServiceEC,
-			getValFunction(), num_batches_per_epoch, val_features, val_labels, modelAvg);
+			getValFunction(), num_batches_per_epoch, val_features, val_labels, nbatches, modelAvg);
 
 		// Create the local workers
 		List<LocalPSWorker> workers = IntStream.range(0, workerNum)
 			.mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
-				getEpochs(), getBatchSize(), workerECs.get(i), ps, modelAvg))
+				getEpochs(), getBatchSize(), workerECs.get(i), ps, nbatches, modelAvg))
 			.collect(Collectors.toList());
 
 		// Do data partition
@@ -473,21 +478,21 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	 * @return parameter server
 	 */
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
-		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, boolean modelAvg)
+		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg)
 	{
-		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, modelAvg);
+		return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg);
 	}
 
 	// When this creation is used the parameter server is able to validate after each epoch
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType,
 		PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc,
-		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, boolean modelAvg)
+		int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg)
 	{
 		switch (mode) {
 			case FEDERATED:
 			case LOCAL:
 			case REMOTE_SPARK:
-				return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, modelAvg);
+				return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
 			default:
 				throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
 		}
@@ -601,4 +606,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 			return DEFAULT_MODELAVG;
 		return Boolean.parseBoolean(getParam(PS_MODELAVG));
 	}
+
+	private int getNbatches() {
+		if(!getParameterMap().containsKey(PS_NBATCHES)) {
+			return DEFAULT_NBATCHES;
+		}
+		return Integer.parseInt(getParam(PS_NBATCHES));
+	}
 }
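
A note on defaults in getNbatches() above: if a script omits the parameter, it
falls back to DEFAULT_NBATCHES = 1, so freq="NBATCHES" without an explicit
nbatches pushes after every single batch, i.e., the same update frequency as
freq="BATCH". A hedged DML sketch (function names illustrative):

  # nbatches defaults to 1 here, matching per-batch update frequency
  m1 = paramserv(model=model, features=X, labels=y,
    upd="./network.dml::gradients", agg="./network.dml::aggregation",
    mode="LOCAL", utype="BSP", freq="NBATCHES", epochs=3, batchsize=32)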
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
new file mode 100644
index 0000000..e2e4f20
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class NbatchesFederatedParamservTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
+	private final static String TEST_DIR = "functions/federated/paramserv/";
+	private final static String TEST_NAME = "NbatchesFederatedParamservTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + NbatchesFederatedParamservTest.class.getSimpleName() + "/";
+
+	private final String _networkType;
+	private final int _numFederatedWorkers;
+	private final int _dataSetSize;
+	private final int _epochs;
+	private final int _batch_size;
+	private final double _eta;
+	private final String _utype;
+	private final String _freq;
+	private final String _scheme;
+	private final String _runtime_balancing;
+	private final String _weighting;
+	private final String _data_distribution;
+	private final int _seed;
+	private final int _nbatches;
+
+	// parameters
+	@Parameterized.Parameters
+	public static Collection<Object[]> parameters() {
+		return Arrays.asList(new Object[][] {
+			// Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency, scheme, runtime balancing, weighting, data distribution, seed, nbatches
+			// basic functionality
+			{"CNN",   2, 2000, 100, 4, 0.01, "BSP", "NBATCHES", "SHUFFLE",             "NONE",      "true", "BALANCED", 200, 8},
+			{"CNN",   2, 480,  32,  4, 0.01, "ASP", "NBATCHES", "REPLICATE_TO_MAX",    "CYCLE_MIN", "true", "BALANCED", 200, 16},
+			{"TwoNN", 5, 2000, 100, 2, 0.01, "BSP", "NBATCHES", "KEEP_DATA_ON_WORKER", "NONE",      "true", "BALANCED", 200, 2},
+		});
+	}
+
+	public NbatchesFederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
+		int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighting, String data_distribution, int seed, int nbatches) {
+
+		_networkType = networkType;
+		_numFederatedWorkers = numFederatedWorkers;
+		_dataSetSize = dataSetSize;
+		_batch_size = batch_size;
+		_epochs = epochs;
+		_eta = eta;
+		_utype = utype;
+		_freq = freq;
+		_scheme = scheme;
+		_runtime_balancing = runtime_balancing;
+		_weighting = weighting;
+		_data_distribution = data_distribution;
+		_seed = seed;
+		_nbatches = nbatches;
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	@Test
+	public void federatedParamservSingleNodeWithNBatches() {
+		runFederatedParamservTest(ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedParamservHybridWithNBatches() {
+		runFederatedParamservTest(ExecMode.HYBRID);
+	}
+
+	private void runFederatedParamservTest(ExecMode mode) {
+		// Warning: Statistics accumulate across unit tests
+		// config
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		setOutputBuffering(false);
+
+		int C = 1, Hin = 28, Win = 28;
+		int numLabels = 10;
+
+		ExecMode platformOld = setExecMode(mode);
+
+		try {
+			// start threads
+			List<Integer> ports = new ArrayList<>();
+			List<Thread> threads = new ArrayList<>();
+			for(int i = 0; i < _numFederatedWorkers; i++) {
+				ports.add(getRandomAvailablePort());
+				threads.add(startLocalFedWorkerThread(ports.get(i), 
+					(i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S)));
+			}
+
+			// generate test data
+			double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+			double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels);
+			String featuresName = "";
+			String labelsName = "";
+
+			// federate test data balanced or imbalanced
+			if(_data_distribution.equals("IMBALANCED")) {
+				featuresName = "X_IMBALANCED_" + _numFederatedWorkers;
+				labelsName = "y_IMBALANCED_" + _numFederatedWorkers;
+				double[][] ranges = {{0,1}, {1,4}};
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges);
+			}
+			else {
+				featuresName = "X_BALANCED_" + _numFederatedWorkers;
+				labelsName = "y_BALANCED_" + _numFederatedWorkers;
+				double[][] ranges = generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, _numFederatedWorkers, ports, ranges);
+				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges);
+			}
+
+			try {
+				//wait for all workers to be setup
+				Thread.sleep(FED_WORKER_WAIT);
+			}
+			catch(InterruptedException e) {
+				e.printStackTrace();
+			}
+
+			// dml name
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			// generate program args
+			List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
+				"-nvargs",
+				"features=" + input(featuresName),
+				"labels=" + input(labelsName),
+				"epochs=" + _epochs,
+				"batch_size=" + _batch_size,
+				"eta=" + _eta,
+				"utype=" + _utype,
+				"freq=" + _freq,
+				"scheme=" + _scheme,
+				"runtime_balancing=" + _runtime_balancing,
+				"weighting=" + _weighting,
+				"network_type=" + _networkType,
+				"channels=" + C,
+				"hin=" + Hin,
+				"win=" + Win,
+				"seed=" + _seed,
+				"nbatches=" + _nbatches));
+
+			programArgs = programArgsList.toArray(new String[0]);
+			LOG.debug(runTest(null));
+			Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+
+			// shut down threads
+			for(int i = 0; i < _numFederatedWorkers; i++) {
+				TestUtils.shutdownThreads(threads.get(i));
+			}
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Generates a feature matrix that has the same format as the MNIST dataset,
+	 * but is completely random and normalized
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param C Channels in the input data
+	 *  @param Hin Height in Pixels of the input data
+	 *  @param Win Width in Pixels of the input data
+	 *  @return a dummy MNIST feature matrix
+	 */
+	private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+	}
+
+	/**
+	 * Generates a label matrix that has the same format as the MNIST dataset, but is completely random and consists
+	 * of one hot encoded vectors as rows
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param numLabels Number of labels to generate
+	 *  @return a dummy MNIST label matrix
+	 */
+	private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
index 0e2f957..351e019 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
@@ -39,37 +39,37 @@ public class ParamservLocalNNTest extends AutomatedTestBase {
 
 	@Test
 	public void testParamservBSPBatchDisjointContiguous() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservASPBatch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(3, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservBSPEpoch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservASPEpoch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(3, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointRoundRobin() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointRandom() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
 	}
 
 	@Test
 	public void testParamservBSPBatchOverlapReshuffle() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
 	}
 
 	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
similarity index 54%
copy from src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
copy to src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
index 0e2f957..b0d21c6 100644
--- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNTestwithNbatches.java
@@ -19,18 +19,18 @@
 
 package org.apache.sysds.test.functions.paramserv;
 
-import org.junit.Test;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
+import org.junit.Test;
 
 @net.jcip.annotations.NotThreadSafe
-public class ParamservLocalNNTest extends AutomatedTestBase {
+public class ParamservLocalNNTestwithNbatches extends AutomatedTestBase {
 
-	private static final String TEST_NAME = "paramserv-test";
+	private static final String TEST_NAME = "paramserv-nbatches-test";
 
 	private static final String TEST_DIR = "functions/paramserv/";
-	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTest.class.getSimpleName() + "/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTestwithNbatches.class.getSimpleName() + "/";
 
 	@Override
 	public void setUp() {
@@ -38,48 +38,28 @@ public class ParamservLocalNNTest extends AutomatedTestBase {
 	}
 
 	@Test
-	public void testParamservBSPBatchDisjointContiguous() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
-	}
-
-	@Test
-	public void testParamservASPBatch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
-	}
-
-	@Test
-	public void testParamservBSPEpoch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
-	}
-
-	@Test
-	public void testParamservASPEpoch() {
-		runDMLTest(10, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
-	}
-
-	@Test
-	public void testParamservBSPBatchDisjointRoundRobin() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
+	public void testParamservBSPNBatchesDisjointContiguous() {
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, false);
 	}
 
 	@Test
-	public void testParamservBSPBatchDisjointRandom() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
+	public void testParamservBSPNBatchesDisjointContiguousModelAvg() {
+		runDMLTest(3, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, true);
 	}
 
 	@Test
-	public void testParamservBSPBatchOverlapReshuffle() {
-		runDMLTest(10, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
+	public void testParamservASPNBatches() {
+		runDMLTest(3, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.NBATCHES, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, 8, false);
 	}
 
-	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
-		TestConfiguration config = getTestConfiguration(ParamservLocalNNTest.TEST_NAME);
+	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, int nbatches, boolean modelAvg) {
+		TestConfiguration config = getTestConfiguration(ParamservLocalNNTestwithNbatches.TEST_NAME);
 		loadTestConfiguration(config);
 		programArgs = new String[] { "-stats", "-nvargs", "mode=LOCAL", "epochs=" + epochs,
 			"workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize,
-			"scheme=" + scheme };
+			"scheme=" + scheme, "nbatches=" + nbatches, "modelAvg=" + modelAvg };
 		String HOME = SCRIPT_DIR + TEST_DIR;
-		fullDMLScriptName = HOME + ParamservLocalNNTest.TEST_NAME + ".dml";
+		fullDMLScriptName = HOME + ParamservLocalNNTestwithNbatches.TEST_NAME + ".dml";
 		runTest(true, false, null, null, -1);
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java
new file mode 100644
index 0000000..11e40a1
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTestwithNbatches.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.paramserv;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Ignore;
+import org.junit.Test;
+
+@net.jcip.annotations.NotThreadSafe
+@Ignore
+public class ParamservSparkNNTestwithNbatches extends AutomatedTestBase {
+
+	private static final String TEST_NAME1 = "paramserv-nbatches-test";
+
+	private static final String TEST_DIR = "functions/paramserv/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSparkNNTestwithNbatches.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+	}
+
+	@Test
+	public void testParamservBSPNbatchesDisjointContiguous() {
+		runDMLTest(2, 2, Statement.PSUpdateType.BSP, Statement.PSFrequency.NBATCHES, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS, 16, false);
+	}
+
+	@Test
+	public void testParamservASPNbatchesDisjointContiguous() {
+		runDMLTest(2, 2, Statement.PSUpdateType.ASP, Statement.PSFrequency.NBATCHES, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS, 16, false);
+	}
+
+	private void internalRunDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException,
+		String errMessage) {
+		ExecMode oldRtplatform = AutomatedTestBase.rtplatform;
+		boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		AutomatedTestBase.rtplatform = ExecMode.HYBRID;
+		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			runTest(true, exceptionExpected, expectedException, errMessage, -1);
+		} finally {
+			AutomatedTestBase.rtplatform = oldRtplatform;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = oldUseLocalSparkConfig;
+		}
+	}
+
+	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, int nbatches, boolean modelAvg) {
+		programArgs = new String[] { "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme, "nbatches=" + nbatches, "modelAvg=" + modelAvg };
+		internalRunDMLTest(TEST_NAME1, false, null, null);
+	}
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml b/src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml
new file mode 100644
index 0000000..dae65fc
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml
@@ -0,0 +1,471 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a convolutional neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net with the "LeNet" architecture, single-threaded, in the conventional way.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
+  matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
+	int Win, int seed = -1, int nbatches) return (list[unknown] model)
+{
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
+    }
+  }
+}
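+
+# Worked example of the batch indexing above (illustrative only): with
+# N=1000 examples and batch_size=64, iters = ceil(1000/64) = 16; for the
+# last batch i=16, beg = ((16-1)*64) %% 1000 + 1 = 961 and
+# end = min(1000, 961+64-1) = 1000, so it covers the 40 remaining rows.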
+
+/*
+ * Trains a convolutional net with the "LeNet" architecture using a parameter server with specified properties.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - Y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - Y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - eta: The learning rate for the SGD updates
+ *  - workers: Number of workers to create
+ *  - utype: Parameter server update type (e.g., BSP or ASP)
+ *  - freq: Update frequency (per batch, per epoch, or every nbatches batches)
+ *  - scheme: Data partitioning scheme for distributing rows to the workers
+ *  - runtime_balancing: Runtime balancing strategy for federated workers
+ *  - weighting: Whether worker updates are weighted (e.g., by local data size)
+ *  - nbatches: Number of batches a worker processes per model update
+ *  - seed: Seed for the weight initialization, or -1 for a random seed
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+  matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
+  string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+  string weighting, double eta, int C, int Hin, int Win, int seed = -1, int nbatches)
+  return (list[unknown] model)
+{
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  learning_rate = eta  # learning rate
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+  # Create the model list
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml::aggregation",
+    val="./src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml::validate",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, hyperparams=hyperparams, seed=seed, nbatches=nbatches)
+}
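+
+# Illustrative sketch of a call to train_paramserv (not executed by the
+# tests; all values below, including the frequency value "NBATCHES" enabled
+# by the new nbatches parameter, are assumptions for illustration). Each
+# worker here processes four mini-batches locally per model update:
+#
+#   model = train_paramserv(X, y, X_val, y_val, 2, 4, "BSP", "NBATCHES", 32,
+#     "DISJOINT_CONTIGUOUS", "NONE", "true", 0.01, 1, 28, 28, 42, 4)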
+
+/*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - batch_size: Batch size
+ *  - model: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+  N = nrow(X)
+
+  # Network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  iters = ceil(N / batch_size)
+  for(i in 1:iters, check=0) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+/*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples.  The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ *  - y: Target matrix, of shape (N, K)
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1)
+ *  - accuracy: Scalar accuracy, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
+  list[unknown] model, list[unknown] hyperparams)
+  return (double loss, double accuracy)
+{
+  [loss, accuracy] = eval(predict(val_features, as.integer(as.scalar(hyperparams["C"])),
+    as.integer(as.scalar(hyperparams["Hin"])), as.integer(as.scalar(hyperparams["Win"])), 
+    32, model), val_labels)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', and 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute loss & accuracy for training data
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  ## loss
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
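+
+# Shape trace of the forward pass above for MNIST-sized inputs
+# (C=1, Hin=Win=28, F1=32, F2=64, N3=512, K=10; illustrative only):
+#   features (N, 1*28*28) -> conv1/relu1 (N, 32*28*28) -> pool1 (N, 32*14*14)
+#   -> conv2/relu2 (N, 64*14*14) -> pool2 (N, 64*7*7) = (N, 3136)
+#   -> affine3/relu3/dropout (N, 512) -> affine4/softmax (N, 10)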
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and always return a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+   W1 = as.matrix(model[1])
+   W2 = as.matrix(model[2])
+   W3 = as.matrix(model[3])
+   W4 = as.matrix(model[4])
+   b1 = as.matrix(model[5])
+   b2 = as.matrix(model[6])
+   b3 = as.matrix(model[7])
+   b4 = as.matrix(model[8])
+   dW1 = as.matrix(gradients[1])
+   dW2 = as.matrix(gradients[2])
+   dW3 = as.matrix(gradients[3])
+   dW4 = as.matrix(gradients[4])
+   db1 = as.matrix(gradients[5])
+   db2 = as.matrix(gradients[6])
+   db3 = as.matrix(gradients[7])
+   db4 = as.matrix(gradients[8])
+   vW1 = as.matrix(model[9])
+   vW2 = as.matrix(model[10])
+   vW3 = as.matrix(model[11])
+   vW4 = as.matrix(model[12])
+   vb1 = as.matrix(model[13])
+   vb2 = as.matrix(model[14])
+   vb3 = as.matrix(model[15])
+   vb4 = as.matrix(model[16])
+   learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+   mu = as.double(as.scalar(hyperparams["mu"]))
+
+   # Optimize with SGD w/ Nesterov momentum
+   [W1, vW1] = sgd_nesterov::update(W1, dW1, learning_rate, mu, vW1)
+   [b1, vb1] = sgd_nesterov::update(b1, db1, learning_rate, mu, vb1)
+   [W2, vW2] = sgd_nesterov::update(W2, dW2, learning_rate, mu, vW2)
+   [b2, vb2] = sgd_nesterov::update(b2, db2, learning_rate, mu, vb2)
+   [W3, vW3] = sgd_nesterov::update(W3, dW3, learning_rate, mu, vW3)
+   [b3, vb3] = sgd_nesterov::update(b3, db3, learning_rate, mu, vb3)
+   [W4, vW4] = sgd_nesterov::update(W4, dW4, learning_rate, mu, vW4)
+   [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
+
+   model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
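+
+# For reference, sgd_nesterov::update (scripts/nn/optim/sgd_nesterov.dml)
+# applies the standard Nesterov-momentum rule, conceptually per parameter W:
+#   v_new = mu * v - learning_rate * dW
+#   W_new = W - mu * v + (1 + mu) * v_new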
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTestwithNbatches.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTestwithNbatches.dml
new file mode 100644
index 0000000..d55b716
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTestwithNbatches.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml") as TwoNNwithNbatches
+source("src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml") as CNNwithNbatches
+
+# create federated input matrices
+features = read($features)
+labels = read($labels)
+
+if($network_type == "TwoNN") {
+  model = TwoNNwithNbatches::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $nbatches)
+  print("Test results:")
+  [loss_test, accuracy_test] = TwoNNwithNbatches::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list())
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
+else {
+  model = CNNwithNbatches::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed, $nbatches)
+  print("Test results:")
+  hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+  [loss_test, accuracy_test] = CNNwithNbatches::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams)
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
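+
+# Hypothetical invocation of this script (all values are illustrative
+# assumptions; the JUnit tests supply these named arguments programmatically):
+#
+#   systemds -f FederatedParamservTestwithNbatches.dml -nvargs
+#     features=X.fed labels=y.fed network_type=TwoNN epochs=2 utype=BSP
+#     freq=NBATCHES batch_size=32 scheme=KEEP_DATA_ON_WORKER
+#     runtime_balancing=NONE weighting=true eta=0.01 channels=1 hin=28
+#     win=28 seed=42 nbatches=2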
diff --git a/src/test/scripts/functions/federated/paramserv/NbatchesFederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/NbatchesFederatedParamservTest.dml
new file mode 100644
index 0000000..d55b716
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/NbatchesFederatedParamservTest.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml") as TwoNNwithNbatches
+source("src/test/scripts/functions/federated/paramserv/CNNwithNbatches.dml") as CNNwithNbatches
+
+# create federated input matrices
+features = read($features)
+labels = read($labels)
+
+if($network_type == "TwoNN") {
+  model = TwoNNwithNbatches::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $nbatches)
+  print("Test results:")
+  [loss_test, accuracy_test] = TwoNNwithNbatches::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list())
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
+else {
+  model = CNNwithNbatches::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed, $nbatches)
+  print("Test results:")
+  hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+  [loss_test, accuracy_test] = CNNwithNbatches::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams)
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml b/src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml
new file mode 100644
index 0000000..e562867
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml
@@ -0,0 +1,305 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all functions needed to train and evaluate a simple
+ * feed-forward neural network on different execution schemes and with
+ * different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed-forward neural network with two hidden layers, single-threaded, in the conventional way.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D)
+ *  - y_val: Target validation matrix of shape (N_val, K)
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - eta: The learning rate for the SGD updates
+ *  - seed: Seed for the weight initialization, or -1 for a random seed
+ *  - nbatches: Number of batches per model update (accepted for signature
+ *      consistency; not used by this conventional, single-threaded trainer)
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (200, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (200, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (K, 1)
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int batch_size, double eta,
+                 int seed = -1, int nbatches)
+    return (list[unknown] model) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200,  seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  W3 = W3 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+  model = list(W1, W2, W3, b1, b2, b3)
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
+    }
+  }
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using a parameter server with specified properties.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D)
+ *  - y_val: Target validation matrix of shape (N_val, K)
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - eta: The learning rate for the SGD updates
+ *  - workers: Number of workers to create
+ *  - utype: Parameter server update type (e.g., BSP or ASP)
+ *  - freq: Update frequency (per batch, per epoch, or every nbatches batches)
+ *  - scheme: Data partitioning scheme for distributing rows to the workers
+ *  - runtime_balancing: Runtime balancing strategy for federated workers
+ *  - weighting: Whether worker updates are weighted (e.g., by local data size)
+ *  - nbatches: Number of batches a worker processes per model update
+ *  - seed: Seed for the weight initialization, or -1 for a random seed
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (200, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (200, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (K, 1)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighting,
+                 double eta, int seed = -1, int nbatches)
+    return (list[unknown] model) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200,  seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  # W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+  # Create the model list
+  model = list(W1, W2, W3, b1, b2, b3)
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta)
+  # Use paramserv function
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml::aggregation",
+    val="./src/test/scripts/functions/federated/paramserv/TwoNNwithNbatches.dml::validate",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, hyperparams=hyperparams, seed=seed, nbatches=nbatches)
+}
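+
+# Illustrative sketch of a call to this function (all values are assumptions
+# for illustration only). With freq="NBATCHES" and nbatches=2, each worker
+# runs two mini-batches per gradient push rather than one per batch or epoch:
+#
+#   model = train_paramserv(X, y, X_val, y_val, 2, 4, "ASP", "NBATCHES", 32,
+#     "DISJOINT_CONTIGUOUS", "NONE", "true", 0.01, 42, 2)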
+
+/*
+ * Computes the class probability predictions of a simple feed forward neural network.
+ *
+ * Inputs:
+ *  - X: The input data matrix of shape (N, D)
+ *  - model: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (200, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (200, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (K, 1)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X,
+                   list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  out1relu = relu::forward(affine::forward(X, W1, b1))
+  out2relu = relu::forward(affine::forward(out1relu, W2, b2))
+  probs = softmax::forward(affine::forward(out2relu, W3, b3))
+}
+
+/*
+ * Evaluates a simple feed forward neural network.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples.  The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K).
+ *  - y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels, list[unknown] model, list[unknown] hyperparams)
+    return (double loss, double accuracy) {
+  [loss, accuracy] = eval(predict(val_features, model), val_labels)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', and 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+    return (list[unknown] gradients) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  # Compute forward pass
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  out1 = affine::forward(features, W1, b1)
+  out1relu = relu::forward(out1)
+  out2 = affine::forward(out1relu, W2, b2)
+  out2relu = relu::forward(out2)
+  out3 = affine::forward(out2relu, W3, b3)
+  probs = softmax::forward(out3)
+
+  # Compute loss & accuracy for training data
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  dout3 = softmax::backward(dprobs, out3)
+  [dout2relu, dW3, db3] = affine::backward(dout3, out2relu, W3, b3)
+  dout2 = relu::backward(dout2relu, out2)
+  [dout1relu, dW2, db2] = affine::backward(dout2, out1relu, W2, b2)
+  dout1 = relu::backward(dout1relu, out1)
+  [dfeatures, dW1, db1] = affine::backward(dout1, features, W1, b1)
+
+  gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
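+
+# Shape trace of the forward pass above for MNIST-sized inputs
+# (D=784, K=10; illustrative only):
+#   features (N, 784) -> affine1/relu1 (N, 200)
+#   -> affine2/relu2 (N, 200) -> affine3/softmax (N, 10)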
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and always return a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+  dW1 = as.matrix(gradients[1])
+  dW2 = as.matrix(gradients[2])
+  dW3 = as.matrix(gradients[3])
+  db1 = as.matrix(gradients[4])
+  db2 = as.matrix(gradients[5])
+  db3 = as.matrix(gradients[6])
+  learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+
+  # Optimize with SGD
+  W3 = sgd::update(W3, dW3, learning_rate)
+  b3 = sgd::update(b3, db3, learning_rate)
+  W2 = sgd::update(W2, dW2, learning_rate)
+  b2 = sgd::update(b2, db2, learning_rate)
+  W1 = sgd::update(W1, dW1, learning_rate)
+  b1 = sgd::update(b1, db1, learning_rate)
+
+  model_result = list(W1, W2, W3, b1, b2, b3)
+}
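+
+# For reference, sgd::update (nn/optim/sgd.dml) applies vanilla SGD,
+# conceptually per parameter W: W_new = W - learning_rate * dW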
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml
new file mode 100644
index 0000000..52de2fb
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml
@@ -0,0 +1,372 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 string utype, string freq, int batchsize, string scheme, string mode, int nbatches, boolean modelAvg)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K, -1)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model list
+  modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
+  # Create the hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val,
+    upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml::gradients",
+    agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml::aggregation",
+    mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, k=workers,
+    scheme=scheme, hyperparams=params, checkpointing="NONE", nbatches=nbatches, modelAvg=modelAvg)
+
+  W1 = as.matrix(modelList2[1])
+  W2 = as.matrix(modelList2[2])
+  W3 = as.matrix(modelList2[3])
+  W4 = as.matrix(modelList2[4])
+  b1 = as.matrix(modelList2[5])
+  b2 = as.matrix(modelList2[6])
+  b3 = as.matrix(modelList2[7])
+  b4 = as.matrix(modelList2[8])
+}
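+
+# Illustrative call of train (all values are assumptions for illustration
+# only). With modelAvg=TRUE the server averages the locally updated worker
+# models instead of applying their gradients, and nbatches=2 triggers one
+# update per two mini-batches:
+#
+#   [W1, b1, W2, b2, W3, b3, W4, b4] = train(X, Y, X_val, Y_val, 1, 28, 28,
+#     2, 2, "BSP", "NBATCHES", 32, "DISJOINT_CONTIGUOUS", "LOCAL", 2, TRUE)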
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', and 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and always return a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+   return (list[unknown] modelResult) {
+     W1 = as.matrix(model[1])
+     W2 = as.matrix(model[2])
+     W3 = as.matrix(model[3])
+     W4 = as.matrix(model[4])
+     b1 = as.matrix(model[5])
+     b2 = as.matrix(model[6])
+     b3 = as.matrix(model[7])
+     b4 = as.matrix(model[8])
+     dW1 = as.matrix(gradients[1])
+     dW2 = as.matrix(gradients[2])
+     dW3 = as.matrix(gradients[3])
+     dW4 = as.matrix(gradients[4])
+     db1 = as.matrix(gradients[5])
+     db2 = as.matrix(gradients[6])
+     db3 = as.matrix(gradients[7])
+     db4 = as.matrix(gradients[8])
+     vW1 = as.matrix(model[9])
+     vW2 = as.matrix(model[10])
+     vW3 = as.matrix(model[11])
+     vW4 = as.matrix(model[12])
+     vb1 = as.matrix(model[13])
+     vb2 = as.matrix(model[14])
+     vb3 = as.matrix(model[15])
+     vb4 = as.matrix(model[16])
+     lr = as.double(as.scalar(hyperparams["lr"]))
+     mu = as.double(as.scalar(hyperparams["mu"]))
+
+     # Optimize with SGD w/ Nesterov momentum
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  iters = ceil(N / batch_size)
+  parfor(i in 1:iters, check=0) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a convolutional net using the "LeNet" architecture.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, Y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, Y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  Y = table(seq(1, N), classes)  # one-hot encoding
+}
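+
+# Micro-example of the one-hot encoding above (illustrative only): for N=3
+# and classes = (2, 1, 3), table(seq(1,3), classes) yields
+#   [0 1 0]
+#   [1 0 0]
+#   [0 0 1]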
+
diff --git a/src/test/scripts/functions/paramserv/paramserv-nbatches-test.dml b/src/test/scripts/functions/paramserv/paramserv-nbatches-test.dml
new file mode 100644
index 0000000..c7e66bd
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nbatches-test.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_nbatches.dml") as mnist_lenet_paramserv_nbatches
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet_paramserv
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet_paramserv_nbatches::generate_dummy_data()
+n = nrow(images)
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet_paramserv_nbatches::train(X, Y, X_val, Y_val, C, Hin, Win, $epochs, $workers, $utype, $freq, $batchsize, $scheme, $mode, $nbatches, $modelAvg)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet_paramserv_nbatches::predict(X_val, C, Hin, Win, $batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)