Posted to commits@systemds.apache.org by ba...@apache.org on 2020/10/30 18:56:47 UTC

[systemds] branch master updated: [SYSTEMDS-2550] Federated Parameter Server

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 428016c  [SYSTEMDS-2550] Federated Parameter Server
428016c is described below

commit 428016c38fb55a3b6094334c7c876b71bf4bf3f7
Author: Tobias Rieger <to...@icloud.com>
AuthorDate: Tue Aug 25 10:23:00 2020 +0200

    [SYSTEMDS-2550] Federated Parameter Server
    
    This commit adds a federated parameter server to the system.
    This allows for federated training of neural networks.
    Particularly verified are two architectures, a feed-forward NN and
    a convolutional NN.
    
    Closes #1075
---
 .../apache/sysds/hops/recompile/Recompiler.java    |   2 +-
 .../java/org/apache/sysds/parser/Statement.java    |  16 +-
 .../controlprogram/caching/MatrixObject.java       |  21 +-
 .../federated/FederatedWorkerHandler.java          |   6 +-
 .../paramserv/FederatedPSControlThread.java        | 560 +++++++++++++++++++++
 .../controlprogram/paramserv/ParamServer.java      |   5 +
 .../controlprogram/paramserv/ParamservUtils.java   |  47 +-
 .../paramserv/dp/DataPartitionFederatedScheme.java |  88 ++++
 .../paramserv/dp/FederatedDataPartitioner.java     |  46 ++
 .../dp/KeepDataOnWorkerFederatedScheme.java        |  32 ++
 .../paramserv/dp/ShuffleFederatedScheme.java       |  33 ++
 .../sysds/runtime/instructions/cp/ListObject.java  | 145 +++++-
 .../cp/ParamservBuiltinCPInstruction.java          |  98 +++-
 .../sysds/runtime/util/ProgramConverter.java       |  40 +-
 .../component/paramserv/SerializationTest.java     |  86 +++-
 .../paramserv/FederatedParamservTest.java          | 195 +++++++
 .../scripts/functions/federated/paramserv/CNN.dml  | 474 +++++++++++++++++
 .../federated/paramserv/FederatedParamservTest.dml |  57 +++
 .../functions/federated/paramserv/TwoNN.dml        | 299 +++++++++++
 19 files changed, 2185 insertions(+), 65 deletions(-)
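
For orientation, a rough sketch of the control flow the new classes implement, mirroring
ParamservBuiltinCPInstruction.runFederated further down in this diff (imports omitted; the
method and all parameter names here are illustrative stand-ins, not part of the patch):

  static void runFederatedSketch(MatrixObject features, MatrixObject labels, ListObject model,
      String updFunc, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq,
      int epochs, long batchSize, ExecutionContext aggServiceEC, List<ExecutionContext> workerECs,
      ExecutorService es) throws Exception {
    // one data slice per federated worker (the data itself stays where it is)
    DataPartitionFederatedScheme.Result result =
      new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
        .doPartitioning(features, labels);
    // a single parameter server plus one control thread per federated worker
    ParamServer ps = LocalParamServer.create(model, aggFunc, updateType, aggServiceEC, result.workerNum);
    List<FederatedPSControlThread> threads = IntStream.range(0, result.workerNum)
      .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, epochs, batchSize, workerECs.get(i), ps))
      .collect(Collectors.toList());
    // ship program, function names and hyper-parameters to each worker, then train
    for (int i = 0; i < threads.size(); i++) {
      threads.get(i).setFeatures(result.pFeatures.get(i));
      threads.get(i).setLabels(result.pLabels.get(i));
      threads.get(i).setup();
    }
    for (Future<Void> f : es.invokeAll(threads))
      f.get(); // each call() pulls the model, computes gradients on the worker, pushes them back
  }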

diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d048863..c785cfc 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1039,7 +1039,7 @@ public class Recompiler
 		}
 	}
 	
-	private static void rRecompileProgramBlock2Forced( ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et ) {
+	public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et ) {
 		if (pb instanceof WhileProgramBlock)
 		{
 			WhileProgramBlock pbTmp = (WhileProgramBlock)pb;
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index f0bdd66..b61b0d6 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -71,7 +71,7 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_MODE = "mode";
 	public static final String PS_GRADIENTS = "gradients";
 	public enum PSModeType {
-		LOCAL, REMOTE_SPARK
+		FEDERATED, LOCAL, REMOTE_SPARK
 	}
 	public static final String PS_UPDATE_TYPE = "utype";
 	public enum PSUpdateType {
@@ -94,12 +94,26 @@ public abstract class Statement implements ParseInfo
 	public enum PSScheme {
 		DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE
 	}
+	public enum FederatedPSScheme {
+		KEEP_DATA_ON_WORKER, SHUFFLE
+	}
 	public static final String PS_HYPER_PARAMS = "hyperparams";
 	public static final String PS_CHECKPOINTING = "checkpointing";
 	public enum PSCheckpointing {
 		NONE, EPOCH, EPOCH10
 	}
 
+	// String constants related to federated parameter server functionality
+	// prefixed with "1701-NCC-" so they do not collide with existing variable names
+	public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
+	public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
+	public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
+	public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
+	public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
+	public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
+	public static final String PS_FED_BATCHCOUNTER_VARID = "1701-NCC-batchcounter_varid";
+	public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
+
 
 	public abstract boolean controlStatement();
 	
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 85d8a8f..4fd2b06 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -66,7 +66,7 @@ import org.apache.sysds.runtime.util.IndexRange;
 public class MatrixObject extends CacheableData<MatrixBlock>
 {
 	private static final long serialVersionUID = 6374712373206495637L;
-	
+
 	public enum UpdateType {
 		COPY,
 		INPLACE,
@@ -87,7 +87,7 @@ public class MatrixObject extends CacheableData<MatrixBlock>
 	private int _partitionSize = -1; //indicates n for BLOCKWISE_N
 	private String _partitionCacheName = null; //name of cache block
 	private MatrixBlock _partitionInMemory = null;
-	
+
 	/**
 	 * Constructor that takes the value type and the HDFS filename.
 	 * 
@@ -112,6 +112,23 @@ public class MatrixObject extends CacheableData<MatrixBlock>
 		_cache = null;
 		_data = null;
 	}
+
+	/**
+	 * Constructor that takes the value type, HDFS filename, associated meta data, and a MatrixBlock;
+	 * used to recreate the object after deserialization
+	 *
+	 * @param vt value type
+	 * @param file file name
+	 * @param mtd metadata
+	 * @param data matrix block data
+	 */
+	public MatrixObject( ValueType vt, String file, MetaData mtd, MatrixBlock data) {
+		super (DataType.MATRIX, vt);
+		_metaData = mtd;
+		_hdfsFileName = file;
+		_cache = null;
+		_data = data;
+	}
 	
 	/**
 	 * Copy constructor that copies meta data but NO data.
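
The new four-argument constructor is what later allows a MatrixObject to be rebuilt from
deserialized parts (see ListObject.readExternal further down). A hedged, illustrative fragment
of such a use; the dimensions, block size and non-zero count are made up:

  // illustrative 4x3 dense block with matching meta data
  MatrixBlock mb = new MatrixBlock(4, 3, false);
  MetaDataFormat mdf = new MetaDataFormat(
    new MatrixCharacteristics(4, 3, 1000, 12), Types.FileFormat.BINARY);
  MatrixObject mo = new MatrixObject(
    Types.ValueType.FP64, Dag.getNextUniqueVarname(Types.DataType.MATRIX), mdf, mb);
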
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 6764f12..e932785 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -244,11 +244,15 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		}
 		
 		//wrap transferred cache block into cacheable data
-		Data data = null;
+		Data data;
 		if( request.getParam(0) instanceof CacheBlock )
 			data = ExecutionContext.createCacheableData((CacheBlock) request.getParam(0));
 		else if( request.getParam(0) instanceof ScalarObject )
 			data = (ScalarObject) request.getParam(0);
+		else if( request.getParam(0) instanceof ListObject )
+			data = (ListObject) request.getParam(0);
+		else
+			throw new DMLRuntimeException("FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock, ScalarObject or ListObject");
 		
 		//set variable and construct empty response
 		ec.setVariable(varname, data);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
new file mode 100644
index 0000000..8fa0698
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.util.ProgramConverter;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+
+import static org.apache.sysds.runtime.util.ProgramConverter.*;
+
+public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
+	FederatedData _featuresData;
+	FederatedData _labelsData;
+	final long _batchCounterVarID;
+	final long _modelVarID;
+	int _totalNumBatches;
+
+	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+		super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+		
+		// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
+		_batchCounterVarID = FederationUtils.getNextFedDataID();
+		_modelVarID = FederationUtils.getNextFedDataID();
+	}
+
+	/**
+	 * Sets up the federated worker and control thread
+	 */
+	public void setup() {
+		// prepare features and labels
+		_features.getFedMapping().forEachParallel((range, data) -> {
+			_featuresData = data;
+			return null;
+		});
+		_labels.getFedMapping().forEachParallel((range, data) -> {
+			_labelsData = data;
+			return null;
+		});
+
+		// calculate number of batches and get data size
+		long dataSize = _features.getNumRows();
+		_totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize);
+
+		// serialize program
+		// create program blocks for the instruction filtering
+		String programSerialized;
+		ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+
+		BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
+		gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
+		programBlocks.add(gradientProgramBlock);
+
+		if(_freq == Statement.PSFrequency.EPOCH) {
+			BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
+			aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
+			programBlocks.add(aggProgramBlock);
+		}
+
+		StringBuilder sb = new StringBuilder();
+		sb.append(PROG_BEGIN);
+		sb.append( NEWLINE );
+		sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
+				programBlocks,
+				new HashMap<>(),
+				false
+		));
+		sb.append(PROG_END);
+		programSerialized = sb.toString();
+
+		// write program and meta data to worker
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				_featuresData.getVarID(),
+				new setupFederatedWorker(_batchSize,
+						dataSize,
+						_totalNumBatches,
+						programSerialized,
+						_inst.getNamespace(),
+						_inst.getFunctionName(),
+						_ps.getAggInst().getFunctionName(),
+						_ec.getListObject("hyperparams"),
+						_batchCounterVarID,
+						_modelVarID
+				)
+		));
+
+		try {
+			FederatedResponse response = udfResponse.get();
+			if(!response.isSuccessful())
+				throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed");
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF: " + e.getMessage());
+		}
+	}
+
+	/**
+	 * Setup UDF executed on the federated worker
+	 */
+	private static class setupFederatedWorker extends FederatedUDF {
+		long _batchSize;
+		long _dataSize;
+		long _numBatches;
+		String _programString;
+		String _namespace;
+		String _gradientsFunctionName;
+		String _aggregationFunctionName;
+		ListObject _hyperParams;
+		long _batchCounterVarID;
+		long _modelVarID;
+
+		protected setupFederatedWorker(long batchSize, long dataSize, long numBatches, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) {
+			super(new long[]{});
+			_batchSize = batchSize;
+			_dataSize = dataSize;
+			_numBatches = numBatches;
+			_programString = programString;
+			_namespace = namespace;
+			_gradientsFunctionName = gradientsFunctionName;
+			_aggregationFunctionName = aggregationFunctionName;
+			_hyperParams = hyperParams;
+			_batchCounterVarID = batchCounterVarID;
+			_modelVarID = modelVarID;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// parse and set program
+			ec.setProgram(ProgramConverter.parseProgram(_programString, 0, false));
+
+			// set variables to ec
+			ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize));
+			ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize));
+			ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches));
+			ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace));
+			ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
+			ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
+			ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
+			ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new IntObject(_batchCounterVarID));
+			ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+	}
+
+	/**
+	 * Cleans up the execution context of the federated worker
+	 */
+	public void teardown() {
+		// write program and meta data to worker
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				_featuresData.getVarID(),
+				new teardownFederatedWorker()
+		));
+
+		try {
+			FederatedResponse response = udfResponse.get();
+			if(!response.isSuccessful())
+				throw new DMLRuntimeException("FederatedLocalPSThread: Teardown UDF failed");
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF: " + e.getMessage());
+		}
+	}
+
+	/**
+	 * Teardown UDF executed on the federated worker
+	 */
+	private static class teardownFederatedWorker extends FederatedUDF {
+		protected teardownFederatedWorker() {
+			super(new long[]{});
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// remove variables from ec
+			ec.removeVariable(Statement.PS_FED_BATCH_SIZE);
+			ec.removeVariable(Statement.PS_FED_DATA_SIZE);
+			ec.removeVariable(Statement.PS_FED_NUM_BATCHES);
+			ec.removeVariable(Statement.PS_FED_NAMESPACE);
+			ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
+			ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
+			ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
+			ec.removeVariable(Statement.PS_FED_MODEL_VARID);
+			ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);
+			ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+			
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+		}
+	}
+
+	/**
+	 * Entry point of the federated control thread; runs the training for the configured update frequency
+	 *
+	 * @return null (Void)
+	 * @throws Exception in case the execution fails
+	 */
+	@Override
+	public Void call() throws Exception {
+		try {
+			switch (_freq) {
+				case BATCH:
+					computeBatch(_totalNumBatches);
+					break;
+				case EPOCH:
+					computeEpoch();
+					break;
+				default:
+					throw new DMLRuntimeException(String.format("%s does not support update frequency %s", getWorkerName(), _freq));
+			}
+		} catch (Exception e) {
+			throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
+		}
+		teardown();
+		return null;
+	}
+
+	protected ListObject pullModel() {
+		// Pull the global parameters from ps
+		return _ps.pull(_workerID);
+	}
+
+	protected void pushGradients(ListObject gradients) {
+		// Push the gradients to ps
+		_ps.push(_workerID, gradients);
+	}
+
+	/**
+	 * Computes all epochs and synchronizes after each batch
+	 *
+	 * @param numBatches the number of batches per epoch
+	 */
+	protected void computeBatch(int numBatches) {
+		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
+			for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+				ListObject model = pullModel();
+				ListObject gradients = computeBatchGradients(model, batchCounter);
+				pushGradients(gradients);
+				ParamservUtils.cleanupListObject(model);
+				ParamservUtils.cleanupListObject(gradients);
+			}
+			System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+		}
+	}
+
+	/**
+	 * Computes a single specified batch on the federated worker
+	 *
+	 * @param model the current model from the parameter server
+	 * @param batchCounter the current batch number needed for slicing the features and labels
+	 * @return the gradient vector
+	 */
+	protected ListObject computeBatchGradients(ListObject model, int batchCounter) {
+		// put batch counter on federated worker
+		Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter)));
+
+		// put current model on federated worker
+		Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+		try {
+			if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
+				throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put: " + e.getMessage());
+		}
+
+		// create and execute the udf on the remote worker
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				_featuresData.getVarID(),
+				new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID})
+		));
+
+		try {
+			Object[] responseData = udfResponse.get().getData();
+			return (ListObject) responseData[0];
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF: " + e.getMessage());
+		}
+	}
+
+	/**
+	 * This is the code that will be executed on the federated Worker when computing a single batch
+	 */
+	private static class federatedComputeBatchGradients extends FederatedUDF {
+		protected federatedComputeBatchGradients(long[] inIDs) {
+			super(inIDs);
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// read in data by varid
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+			long batchCounter = ((IntObject) data[2]).getLongValue();
+			ListObject model = (ListObject) data[3];
+
+			// get data from execution context
+			long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+			long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+			String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+			String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+
+			// slice batch from feature and label matrix
+			long begin = batchCounter * batchSize + 1;
+			long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+			MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+			MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+			// prepare execution context
+			ec.setVariable(Statement.PS_MODEL, model);
+			ec.setVariable(Statement.PS_FEATURES, bFeatures);
+			ec.setVariable(Statement.PS_LABELS, bLabels);
+
+			// recreate gradient instruction and output
+			FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+			ArrayList<DataIdentifier> inputs = func.getInputParams();
+			ArrayList<DataIdentifier> outputs = func.getOutputParams();
+			CPOperand[] boundInputs = inputs.stream()
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
+			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+			Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
+					func.getInputParamNames(), outputNames, "gradient function");
+			DataIdentifier gradientsOutput = outputs.get(0);
+
+			// calculate gradients
+			gradientsInstruction.processInstruction(ec);
+			ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+			// clean up sliced batch
+			ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
+			ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+			ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+
+			// model clean up - doing this twice is not an issue
+			ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
+			ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+			// return
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients);
+		}
+	}
+
+	/**
+	 * Computes all epochs and synchronizes after each one
+	 */
+	protected void computeEpoch() {
+		for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
+			// Pull the global parameters from ps
+			ListObject model = pullModel();
+			ListObject gradients = computeEpochGradients(model);
+			pushGradients(gradients);
+			System.out.println("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+			ParamservUtils.cleanupListObject(model);
+			ParamservUtils.cleanupListObject(gradients);
+		}
+	}
+
+	/**
+	 * Computes one epoch on the federated worker and updates the model locally
+	 *
+	 * @param model the current model from the parameter server
+	 * @return the gradient vector
+	 */
+	protected ListObject computeEpochGradients(ListObject model) {
+		// put current model on federated worker
+		Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+		try {
+			if(!putParamsResponse.get().isSuccessful())
+				throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put: " + e.getMessage());
+		}
+
+		// create and execute the udf on the remote worker
+		Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+				_featuresData.getVarID(),
+				new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID})
+		));
+
+		try {
+			Object[] responseData = udfResponse.get().getData();
+			return (ListObject) responseData[0];
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF: " + e.getMessage());
+		}
+	}
+
+	/**
+	 * This is the code that will be executed on the federated Worker when computing one epoch
+	 */
+	private static class federatedComputeEpochGradients extends FederatedUDF {
+		protected federatedComputeEpochGradients(long[] inIDs) {
+			super(inIDs);
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// read in data by varid
+			MatrixObject features = (MatrixObject) data[0];
+			MatrixObject labels = (MatrixObject) data[1];
+			ListObject model = (ListObject) data[2];
+
+			// get data from execution context
+			long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+			long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+			long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue();
+			String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+			String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+			String aggregationFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+
+			// recreate gradient instruction and output
+			FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+			ArrayList<DataIdentifier> inputs = func.getInputParams();
+			ArrayList<DataIdentifier> outputs = func.getOutputParams();
+			CPOperand[] boundInputs = inputs.stream()
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
+			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+			Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
+					func.getInputParamNames(), outputNames, "gradient function");
+			DataIdentifier gradientsOutput = outputs.get(0);
+
+			// recreate aggregation instruction and output
+			func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFunctionName, false);
+			inputs = func.getInputParams();
+			outputs = func.getOutputParams();
+			boundInputs = inputs.stream()
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
+			outputNames = outputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+			Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFunctionName, false, boundInputs,
+					func.getInputParamNames(), outputNames, "aggregation function");
+			DataIdentifier aggregationOutput = outputs.get(0);
+
+
+			ListObject accGradients = null;
+			// prepare execution context
+			ec.setVariable(Statement.PS_MODEL, model);
+			for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+				// slice batch from feature and label matrix
+				long begin = batchCounter * batchSize + 1;
+				long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+				MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+				MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+				// prepare execution context
+				ec.setVariable(Statement.PS_FEATURES, bFeatures);
+				ec.setVariable(Statement.PS_LABELS, bLabels);
+				boolean localUpdate = batchCounter < numBatches - 1;
+
+				// calculate intermediate gradients
+				gradientsInstruction.processInstruction(ec);
+				ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+				// TODO: is this equivalent for momentum based and AMS prob?
+				accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);
+
+				// Update the local model with gradients
+				if(localUpdate) {
+					// Invoke the aggregate function
+					aggregationInstruction.processInstruction(ec);
+					// Get the new model
+					model = ec.getListObject(aggregationOutput.getName());
+					// Set new model in execution context
+					ec.setVariable(Statement.PS_MODEL, model);
+					// clean up gradients and result
+					ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+					ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
+				}
+
+				// clean up sliced batch
+				ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+				ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+			}
+
+			// model clean up - doing this twice is not an issue
+			ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
+			ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
+		}
+	}
+
+	// Statistics methods
+	@Override
+	public String getWorkerName() {
+		return String.format("Federated worker_%d", _workerID);
+	}
+
+	@Override
+	protected void incWorkerNumber() {
+
+	}
+
+	@Override
+	protected void accLocalModelUpdateTime(Timing time) {
+
+	}
+
+	@Override
+	protected void accBatchIndexingTime(Timing time) {
+
+	}
+
+	@Override
+	protected void accGradientComputeTime(Timing time) {
+
+	}
+}
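
A small worked example of the batch-slicing arithmetic used above in setup(),
computeBatchGradients and federatedComputeEpochGradients (the numbers are illustrative):

  long dataSize  = 1000;  // rows held by one federated worker
  long batchSize = 128;
  int numBatches = (int) Math.ceil((double) dataSize / batchSize); // = 8, as in setup()
  for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
    long begin = batchCounter * batchSize + 1;                       // 1-based row index
    long end   = Math.min((batchCounter + 1) * batchSize, dataSize);
    // batch 0 covers rows 1..128, ..., the last batch (7) covers rows 897..1000;
    // ParamservUtils.sliceMatrix(features, begin, end) then yields that row range
  }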
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 276f56c..e420ed8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -57,6 +57,7 @@ public abstract class ParamServer
 	//aggregation service
 	protected ExecutionContext _ec;
 	private Statement.PSUpdateType _updateType;
+
 	private FunctionCallCPInstruction _inst;
 	private String _outputName;
 	private boolean[] _finishedStates;  // Workers' finished states
@@ -232,4 +233,8 @@ public abstract class ParamServer
 		if (DMLScript.STATISTICS)
 			Statistics.accPSModelBroadcastTime((long) tBroad.stop());
 	}
+
+	public FunctionCallCPInstruction getAggInst() {
+		return _inst;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 968cb1d..e63fb14 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -31,6 +31,7 @@ import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.MultiThreadedHop;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.LopProperties;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.Statement;
@@ -214,15 +215,21 @@ public class ParamservUtils {
 	}
 
 	public static ExecutionContext createExecutionContext(ExecutionContext ec,
-		LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+	  	LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+	{
+		return createExecutionContext(ec, varsMap, updFunc, aggFunc, k, false);
+	}
+
+	public static ExecutionContext createExecutionContext(ExecutionContext ec,
+		LocalVariableMap varsMap, String updFunc, String aggFunc, int k, boolean forceExecTypeCP)
 	{
 		Program prog = ec.getProgram();
 
 		// 1. Recompile the internal program blocks 
-		recompileProgramBlocks(k, prog.getProgramBlocks());
+		recompileProgramBlocks(k, prog.getProgramBlocks(), forceExecTypeCP);
 		// 2. Recompile the imported function blocks
 		prog.getFunctionProgramBlocks(false)
-			.forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks()));
+			.forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks(), forceExecTypeCP));
 
 		// 3. Copy all functions 
 		return ExecutionContextFactory.createContext(
@@ -249,6 +256,10 @@ public class ParamservUtils {
 	}
 
 	public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs) {
+		recompileProgramBlocks(k, pbs, false);
+	}
+
+	public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs, boolean forceExecTypeCP) {
 		// Reset the visit status from root
 		for (ProgramBlock pb : pbs)
 			DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
@@ -256,43 +267,49 @@ public class ParamservUtils {
 		// Should recursively assign the level of parallelism
 		// and recompile the program block
 		try {
-			rAssignParallelism(pbs, k, false);
+			if(forceExecTypeCP)
+				rAssignParallelismAndRecompile(pbs, k, true, forceExecTypeCP);
+			else
+				rAssignParallelismAndRecompile(pbs, k, false, forceExecTypeCP);
 		} catch (IOException e) {
 			throw new DMLRuntimeException(e);
 		}
 	}
 
-	private static boolean rAssignParallelism(List<ProgramBlock> pbs, int k, boolean recompiled) throws IOException {
+	private static boolean rAssignParallelismAndRecompile(List<ProgramBlock> pbs, int k, boolean recompiled, boolean forceExecTypeCP) throws IOException {
 		for (ProgramBlock pb : pbs) {
 			if (pb instanceof ParForProgramBlock) {
 				ParForProgramBlock pfpb = (ParForProgramBlock) pb;
 				pfpb.setDegreeOfParallelism(k);
-				recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
+				recompiled |= rAssignParallelismAndRecompile(pfpb.getChildBlocks(), 1, recompiled, forceExecTypeCP);
 			} else if (pb instanceof ForProgramBlock) {
-				recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
+				recompiled |= rAssignParallelismAndRecompile(((ForProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
 			} else if (pb instanceof WhileProgramBlock) {
-				recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
+				recompiled |= rAssignParallelismAndRecompile(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
 			} else if (pb instanceof FunctionProgramBlock) {
-				recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
+				recompiled |= rAssignParallelismAndRecompile(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
 			} else if (pb instanceof IfProgramBlock) {
 				IfProgramBlock ipb = (IfProgramBlock) pb;
-				recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
+				recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksIfBody(), k, recompiled, forceExecTypeCP);
 				if (ipb.getChildBlocksElseBody() != null)
-					recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
+					recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksElseBody(), k, recompiled, forceExecTypeCP);
 			} else {
 				StatementBlock sb = pb.getStatementBlock();
 				for (Hop hop : sb.getHops())
-					recompiled |= rAssignParallelism(hop, k, recompiled);
+					recompiled |= rAssignParallelismAndRecompile(hop, k, recompiled);
 			}
 			// Recompile the program block
 			if (recompiled) {
-				Recompiler.recompileProgramBlockInstructions(pb);
+				if(forceExecTypeCP)
+					Recompiler.rRecompileProgramBlock2Forced(pb, pb.getThreadID(), new HashSet<>(), LopProperties.ExecType.CP);
+				else
+					Recompiler.recompileProgramBlockInstructions(pb);
 			}
 		}
 		return recompiled;
 	}
 
-	private static boolean rAssignParallelism(Hop hop, int k, boolean recompiled) {
+	private static boolean rAssignParallelismAndRecompile(Hop hop, int k, boolean recompiled) {
 		if (hop.isVisited()) {
 			return recompiled;
 		}
@@ -304,7 +321,7 @@ public class ParamservUtils {
 		}
 		ArrayList<Hop> inputs = hop.getInput();
 		for (Hop h : inputs) {
-			recompiled |= rAssignParallelism(h, k, recompiled);
+			recompiled |= rAssignParallelismAndRecompile(h, k, recompiled);
 		}
 		hop.setVisited();
 		return recompiled;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
new file mode 100644
index 0000000..4183372
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+public abstract class DataPartitionFederatedScheme {
+
+	public static final class Result {
+		public final List<MatrixObject> pFeatures;
+		public final List<MatrixObject> pLabels;
+		public final int workerNum;
+
+		public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
+			this.pFeatures = pFeatures;
+			this.pLabels = pLabels;
+			this.workerNum = workerNum;
+		}
+	}
+
+	public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+
+	/**
+	 * Takes a row-federated matrix and slices it into one matrix per federated worker
+	 *
+	 * @param fedMatrix the federated input matrix
+	 * @return a list with one sliced MatrixObject per federated worker
+	 */
+	static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
+		if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
+
+			List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
+			fedMatrix.getFedMapping().forEachParallel((range, data) -> {
+				// Create sliced matrix object
+				MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
+				// Note: the slice needs MetaDataFormat (not plain MetaData) as its meta data
+				slice.setMetaData(new MetaDataFormat(
+						new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
+						Types.FileFormat.BINARY)
+				);
+
+				// Create new federation map
+				HashMap<FederatedRange, FederatedData> newFedHashMap = new HashMap<>();
+				newFedHashMap.put(range, data);
+				slice.setFedMapping(new FederationMap(fedMatrix.getFedMapping().getID(), newFedHashMap));
+				slice.getFedMapping().setType(FederationMap.FType.ROW);
+
+				slices.add(slice);
+				return null;
+			});
+
+			return slices;
+		}
+		else {
+			throw new DMLRuntimeException("Federated data partitioner: " +
+					"currently only supports row federated data");
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
new file mode 100644
index 0000000..4cdfb95
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+public class FederatedDataPartitioner {
+
+	private final DataPartitionFederatedScheme _scheme;
+
+	public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+		switch (scheme) {
+			case KEEP_DATA_ON_WORKER:
+				_scheme = new KeepDataOnWorkerFederatedScheme();
+				break;
+			case SHUFFLE:
+				_scheme = new ShuffleFederatedScheme();
+				break;
+			default:
+				throw new DMLRuntimeException(String.format("FederatedDataPartitioner: unsupported data partition scheme '%s'", scheme));
+		}
+	}
+
+	public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		return _scheme.doPartitioning(features, labels);
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
new file mode 100644
index 0000000..06feded
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import java.util.List;
+
+public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
+	@Override
+	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		return new Result(pFeatures, pLabels, pFeatures.size());
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
new file mode 100644
index 0000000..d6d8cfc
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+import java.util.List;
+
+public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
+	@Override
+	public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+		List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+		List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+		return new Result(pFeatures, pLabels, pFeatures.size());
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 4f726ee..a8397cb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -19,17 +19,29 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 
-public class ListObject extends Data {
+public class ListObject extends Data implements Externalizable {
 	private static final long serialVersionUID = 3652422061598967358L;
 
 	private final List<Data> _data;
@@ -37,6 +49,14 @@ public class ListObject extends Data {
 	private List<String> _names = null;
 	private int _nCacheable;
 	private List<LineageItem> _lineage = null;
+
+	/*
+	 * Default constructor, required by the Externalizable interface
+	 */
+	public ListObject() {
+		super(DataType.LIST, ValueType.UNKNOWN);
+		_data = new ArrayList<>();
+	}
 	
 	public ListObject(List<Data> data) {
 		this(data, null, null);
@@ -286,4 +306,127 @@ public class ListObject extends Data {
 		sb.append(")");
 		return sb.toString();
 	}
+
+	/**
+	 * Custom serialization via Externalizable: writes out the list structure and each
+	 * entry (nested lists, matrices including meta data, scalars) for efficient transfer.
+	 *
+	 * @param out object output
+	 * @throws IOException if IOException occurs
+	 */
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		// write out length
+		out.writeInt(getLength());
+		// write out num cacheable
+		out.writeInt(_nCacheable);
+
+		// write out names for named list
+		out.writeBoolean(getNames() != null);
+		if(getNames() != null) {
+			for (int i = 0; i < getLength(); i++) {
+				out.writeObject(_names.get(i));
+			}
+		}
+
+		// write out data
+		for(int i = 0; i < getLength(); i++) {
+			Data d = getData(i);
+			out.writeObject(d.getDataType());
+			out.writeObject(d.getValueType());
+			out.writeObject(d.getPrivacyConstraint());
+			switch(d.getDataType()) {
+				case LIST:
+					ListObject lo = (ListObject) d;
+					out.writeObject(lo);
+					break;
+				case MATRIX:
+					MatrixObject mo = (MatrixObject) d;
+					MetaDataFormat md = (MetaDataFormat) mo.getMetaData();
+					DataCharacteristics dc = md.getDataCharacteristics();
+
+					out.writeObject(dc.getRows());
+					out.writeObject(dc.getCols());
+					out.writeObject(dc.getBlocksize());
+					out.writeObject(dc.getNonZeros());
+					out.writeObject(md.getFileFormat());
+					out.writeObject(mo.acquireReadAndRelease());
+					break;
+				case SCALAR:
+					ScalarObject so = (ScalarObject) d;
+					out.writeObject(so.getStringValue());
+					break;
+				default:
+					throw new DMLRuntimeException("Unable to serialize datatype " + d.getDataType());
+			}
+		}
+	}
+
+	/**
+	 * Custom deserialization via Externalizable: reads the list structure and reconstructs
+	 * each entry (nested lists, matrices including meta data, scalars).
+	 *
+	 * @param in object input
+	 * @throws IOException if IOException occurs
+	 */
+	@Override
+	public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+		// read in length
+		int length = in.readInt();
+		// read in num cacheable
+		_nCacheable = in.readInt();
+
+		// read in names
+		boolean hasNames = in.readBoolean();
+		if(hasNames) {
+			_names = new ArrayList<>();
+			for (int i = 0; i < length; i++) {
+				_names.add((String) in.readObject());
+			}
+		}
+
+		// read in data
+		for(int i = 0; i < length; i++) {
+			DataType dataType = (DataType) in.readObject();
+			ValueType valueType = (ValueType) in.readObject();
+			PrivacyConstraint privacyConstraint = (PrivacyConstraint) in.readObject();
+			Data d;
+			switch(dataType) {
+				case LIST:
+					d = (ListObject) in.readObject();
+					break;
+				case MATRIX:
+					long rows = (long) in.readObject();
+					long cols = (long) in.readObject();
+					int blockSize = (int) in.readObject();
+					long nonZeros = (long) in.readObject();
+					Types.FileFormat fileFormat = (Types.FileFormat) in.readObject();
+
+					// construct objects and set meta data
+					MatrixCharacteristics matrixCharacteristics = new MatrixCharacteristics(rows, cols, blockSize, nonZeros);
+					MetaDataFormat metaDataFormat = new MetaDataFormat(matrixCharacteristics, fileFormat);
+					MatrixBlock matrixBlock = (MatrixBlock) in.readObject();
+
+					d = new MatrixObject(valueType, Dag.getNextUniqueVarname(Types.DataType.MATRIX), metaDataFormat, matrixBlock);
+					break;
+				case SCALAR:
+					String value = (String) in.readObject();
+					ScalarObject so;
+					switch (valueType) {
+						case INT64:     so = new IntObject(Long.parseLong(value)); break;
+						case FP64:  	so = new DoubleObject(Double.parseDouble(value)); break;
+						case BOOLEAN: 	so = new BooleanObject(Boolean.parseBoolean(value)); break;
+						case STRING:  	so = new StringObject(value); break;
+						default:
+							throw new DMLRuntimeException("Unable to parse valuetype " + valueType);
+					}
+					d = so;
+					break;
+				default:
+					throw new DMLRuntimeException("Unable to deserialize datatype " + dataType);
+			}
+			d.setPrivacyConstraints(privacyConstraint);
+			_data.add(d);
+		}
+	}
 }
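
With ListObject now Externalizable, it can round-trip through plain Java object streams,
which is presumably what the federated PUT_VAR and response transfer of the model and
gradients relies on. A minimal, illustrative round trip over scalar entries only (the class
name and values are made up; the SystemDS types are as in the diff above):

  import java.io.*;
  import java.util.Arrays;
  import org.apache.sysds.runtime.instructions.cp.Data;
  import org.apache.sysds.runtime.instructions.cp.IntObject;
  import org.apache.sysds.runtime.instructions.cp.ListObject;
  import org.apache.sysds.runtime.instructions.cp.StringObject;

  public class ListObjectRoundTrip {
    public static void main(String[] args) throws Exception {
      ListObject original = new ListObject(
        Arrays.<Data>asList(new IntObject(42), new StringObject("lr=0.01")));

      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
        oos.writeObject(original); // triggers ListObject.writeExternal
      }
      try (ObjectInputStream ois = new ObjectInputStream(
          new ByteArrayInputStream(bos.toByteArray()))) {
        ListObject copy = (ListObject) ois.readObject(); // triggers readExternal
        System.out.println(copy.getLength()); // 2
      }
    }
  }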
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index c42ec91..5e8ad32 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -52,6 +52,7 @@ import org.apache.spark.util.LongAccumulator;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.parser.Statement;
 import org.apache.sysds.parser.Statement.PSFrequency;
 import org.apache.sysds.parser.Statement.PSModeType;
 import org.apache.sysds.parser.Statement.PSScheme;
@@ -61,13 +62,16 @@ import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
 import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
 import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
+import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
 import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
 import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -91,16 +95,87 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 	@Override
 	public void processInstruction(ExecutionContext ec) {
-		PSModeType mode = getPSMode();
-		switch (mode) {
-			case LOCAL:
-				runLocally(ec, mode);
-				break;
-			case REMOTE_SPARK:
-				runOnSpark((SparkExecutionContext) ec, mode);
-				break;
-			default:
-				throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+		// check if the input is federated
+		if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+				ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+			runFederated(ec);
+		}
+		// if not federated check mode
+		else {
+			PSModeType mode = getPSMode();
+			switch (mode) {
+				case LOCAL:
+					runLocally(ec, mode);
+					break;
+				case REMOTE_SPARK:
+					runOnSpark((SparkExecutionContext) ec, mode);
+					break;
+				default:
+					throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+			}
+		}
+	}
+
+	private void runFederated(ExecutionContext ec) {
+		System.out.println("PARAMETER SERVER");
+		System.out.println("[+] Running in federated mode");
+
+		// get inputs
+		PSFrequency freq = getFrequency();
+		PSUpdateType updateType = getUpdateType();
+		String updFunc = getParam(PS_UPDATE_FUN);
+		String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+		// partition federated data
+		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+				.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+		List<MatrixObject> pFeatures = result.pFeatures;
+		List<MatrixObject> pLabels = result.pLabels;
+		int workerNum = result.workerNum;
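+		// one control thread (and thus one parameter server worker) is launched per federated partition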
+
+		// setup threading
+		BasicThreadFactory factory = new BasicThreadFactory.Builder()
+				.namingPattern("workers-pool-thread-%d").build();
+		ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
+
+		// Get the compiled execution context
+		LocalVariableMap newVarsMap = createVarsMap(ec);
+		// Level of par is 1 because one worker will be launched per task
+		// TODO: Fix recompilation
+		ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1, true);
+		// Create workers' execution context
+		List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
+		// Create the agg service's execution context
+		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+		// Create the parameter server
+		ListObject model = ec.getListObject(getParam(PS_MODEL));
+		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
+		// Create the local workers
+		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
+				.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps))
+				.collect(Collectors.toList());
+
+		if(workerNum != threads.size()) {
+			throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
+		}
+
+		// Set features and labels for the control threads and write the program, instructions and hyperparameters to the federated workers
+		for (int i = 0; i < threads.size(); i++) {
+			threads.get(i).setFeatures(pFeatures.get(i));
+			threads.get(i).setLabels(pLabels.get(i));
+			threads.get(i).setup();
+		}
+
+		try {
+			// Launch the worker threads and wait for completion
+			for (Future<Void> ret : es.invokeAll(threads))
+				ret.get(); //error handling
+			// Fetch the final model from ps
+			ec.setVariable(output.getName(), ps.getResult());
+		} catch (InterruptedException | ExecutionException e) {
+			throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
+		} finally {
+			es.shutdownNow();
 		}
 	}
 
@@ -150,7 +225,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
 		
 		// Create remote workers
-		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), 
+		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
 			getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
 			server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);
 
@@ -333,6 +408,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	 */
 	private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
 		switch (mode) {
+			case FEDERATED:
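+				// federated mode reuses the local in-memory parameter server;
+				// only the workers (FederatedPSControlThread) operate on remote data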
 			case LOCAL:
 			case REMOTE_SPARK:
 				return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 7dad319..16472c7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -817,7 +817,7 @@ public class ProgramConverter
 		//handle program
 		sb.append(PROG_BEGIN);
 		sb.append( NEWLINE );
-		sb.append( serializeProgram(prog, pbs, clsMap) );
+		sb.append( serializeProgram(prog, pbs, clsMap, true) );
 		sb.append(PROG_END);
 		sb.append( NEWLINE );
 		sb.append( COMPONENTS_DELIM );
@@ -849,32 +849,32 @@ public class ProgramConverter
 		return sb.toString();
 	}
 
-	private static String serializeProgram( Program prog, ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap ) {
-		//note program contains variables, programblocks and function program blocks 
+	public static String serializeProgram( Program prog, ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap, boolean opt) {
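+		//'opt' is passed through to the function-program-block lookups (optimized vs. unoptimized blocks)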
+		//note program contains variables, programblocks and function program blocks
 		//but in order to avoid redundancy, we only serialize function program blocks
-		HashMap<String, FunctionProgramBlock> fpb = prog.getFunctionProgramBlocks();
+		HashMap<String, FunctionProgramBlock> fpb = prog.getFunctionProgramBlocks(opt);
 		HashSet<String> cand = new HashSet<>();
-		rFindSerializationCandidates(pbs, cand);
-		return rSerializeFunctionProgramBlocks( fpb, cand, clsMap );
+		rFindSerializationCandidates(pbs, cand, opt);
+		return rSerializeFunctionProgramBlocks(fpb, cand, clsMap);
 	}
 
-	private static void rFindSerializationCandidates( ArrayList<ProgramBlock> pbs, HashSet<String> cand ) 
+	private static void rFindSerializationCandidates( ArrayList<ProgramBlock> pbs, HashSet<String> cand, boolean opt)
 	{
 		for( ProgramBlock pb : pbs )
 		{
 			if( pb instanceof WhileProgramBlock ) {
 				WhileProgramBlock wpb = (WhileProgramBlock) pb;
-				rFindSerializationCandidates(wpb.getChildBlocks(), cand );
+				rFindSerializationCandidates(wpb.getChildBlocks(), cand, opt);
 			}
 			else if ( pb instanceof ForProgramBlock || pb instanceof ParForProgramBlock ) {
 				ForProgramBlock fpb = (ForProgramBlock) pb; 
-				rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+				rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
 			}
 			else if ( pb instanceof IfProgramBlock ) {
 				IfProgramBlock ipb = (IfProgramBlock) pb;
-				rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand);
+				rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand, opt);
 				if( ipb.getChildBlocksElseBody() != null )
-					rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand);
+					rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand, opt);
 			}
 			else if( pb instanceof BasicProgramBlock ) { 
 				BasicProgramBlock bpb = (BasicProgramBlock) pb;
@@ -885,8 +885,8 @@ public class ProgramConverter
 						if( !cand.contains(fkey) ) { //memoization for multiple calls, recursion
 							cand.add( fkey ); //add to candidates
 							//investigate chains of function calls
-							FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), fci.getFunctionName());
-							rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+							FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), fci.getFunctionName(), opt);
+							rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
 						}
 					}
 			}
@@ -985,12 +985,12 @@ public class ProgramConverter
 	}
 
 	@SuppressWarnings("all")
-	private static String serializeInstructions( ArrayList<Instruction> inst, HashMap<String, byte[]> clsMap ) 
+	private static String serializeInstructions( ArrayList<Instruction> inst, HashMap<String, byte[]> clsMap )
 	{
 		StringBuilder sb = new StringBuilder();
 		int count = 0;
 		for( Instruction linst : inst ) {
-			//check that only cp instruction are transmitted 
+			//check that only cp instruction are transmitted
 			if( !( linst instanceof CPInstruction) )
 				throw new DMLRuntimeException( NOT_SUPPORTED_SPARK_INSTRUCTION + " " +linst.getClass().getName()+"\n"+linst );
 			
@@ -1098,7 +1098,6 @@ public class ProgramConverter
 				continue;
 			if( count>0 ) {
 				sb.append( ELEMENT_DELIM );
-				sb.append( NEWLINE );
 			}
 			sb.append( pb.getKey() );
 			sb.append( KEY_VALUE_DELIM );
@@ -1115,7 +1114,6 @@ public class ProgramConverter
 		for( ProgramBlock pb : pbs ) {
 			if( count>0 ) {
 				sb.append( ELEMENT_DELIM );
-				sb.append(NEWLINE);
 			}
 			sb.append( rSerializeProgramBlock(pb, clsMap) );
 			count++;
@@ -1339,6 +1337,10 @@ public class ProgramConverter
 	}
 
 	public static Program parseProgram( String in, int id ) {
+		return parseProgram(in, id, true);
+	}
+
+	public static Program parseProgram( String in, int id, boolean opt ) {
 		String lin = in.substring( PROG_BEGIN.length(),in.length()- PROG_END.length()).trim();
 		Program prog = new Program();
 		HashMap<String,FunctionProgramBlock> fc = parseFunctionProgramBlocks(lin, prog, id);
@@ -1346,7 +1348,7 @@ public class ProgramConverter
 			String[] keypart = e.getKey().split( Program.KEY_DELIM );
 			String namespace = keypart[0];
 			String name      = keypart[1];
-			prog.addFunctionProgramBlock(namespace, name, e.getValue());
+			prog.addFunctionProgramBlock(namespace, name, e.getValue(), opt);
 		}
 		return prog;
 	}
@@ -1354,7 +1356,7 @@ public class ProgramConverter
 	private static LocalVariableMap parseVariables(String in) {
 		LocalVariableMap ret = null;
 		if( in.length()> VARS_BEGIN.length() + VARS_END.length()) {
-			String varStr = in.substring( VARS_BEGIN.length(),in.length()- VARS_END.length()).trim();
+			String varStr = in.substring( VARS_BEGIN.length(),in.length() - VARS_END.length()).trim();
 			ret = LocalVariableMap.deserialize(varStr);
 		}
 		else { //empty input symbol table
diff --git a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
index 665997a..bf47f19 100644
--- a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
+++ b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
@@ -19,8 +19,15 @@
 
 package org.apache.sysds.test.component.paramserv;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectInput;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
 import java.util.Arrays;
+import java.util.Collection;
 
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -29,32 +36,80 @@ import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.ProgramConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+@RunWith(Parameterized.class)
 public class SerializationTest {
+	private int _named;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> named() {
+		return Arrays.asList(new Object[][] {
+				{ 0 },
+				{ 1 }
+		});
+	}
+
+	public SerializationTest(Integer named) {
+		this._named = named;
+	}
 
 	@Test
-	public void serializeUnnamedListObject() {
+	public void serializeListObject() {
 		MatrixObject mo1 = generateDummyMatrix(10);
 		MatrixObject mo2 = generateDummyMatrix(20);
 		IntObject io = new IntObject(30);
-		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
-		String serial = ProgramConverter.serializeDataObject("key", lo);
-		Object[] obj = ProgramConverter.parseDataObject(serial);
-		ListObject actualLO = (ListObject) obj[1];
-		MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
-		MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
-		IntObject actualIO = (IntObject) actualLO.slice(2);
-		Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), actualMO1.acquireRead().getDenseBlockValues(), 0);
-		Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), actualMO2.acquireRead().getDenseBlockValues(), 0);
-		Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
+		ListObject lot = new ListObject(Arrays.asList(mo2));
+		ListObject lo;
+
+		if (_named == 1)
+			 lo = new ListObject(Arrays.asList(mo1, lot, io), Arrays.asList("e1", "e2", "e3"));
+		else
+			lo = new ListObject(Arrays.asList(mo1, lot, io));
+
+		ListObject loDeserialized = null;
+
+		// serialize and back
+		try {
+			ByteArrayOutputStream bos = new ByteArrayOutputStream();
+			ObjectOutputStream out = new ObjectOutputStream(bos);
+			out.writeObject(lo);
+			out.flush();
+			byte[] loBytes = bos.toByteArray();
+
+			ByteArrayInputStream bis = new ByteArrayInputStream(loBytes);
+			ObjectInput in = new ObjectInputStream(bis);
+			loDeserialized = (ListObject) in.readObject();
+		}
+		catch(Exception e) {
+			Assert.fail("Error while serializing and deserializing to bytes: " + e);
+		}
+
+		MatrixObject mo1Deserialized = (MatrixObject) loDeserialized.getData(0);
+		ListObject lotDeserialized = (ListObject) loDeserialized.getData(1);
+		MatrixObject mo2Deserialized = (MatrixObject) lotDeserialized.getData(0);
+		IntObject ioDeserialized = (IntObject) loDeserialized.getData(2);
+
+		if (_named == 1)
+			Assert.assertEquals(lo.getNames(), loDeserialized.getNames());
+
+		Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), mo1Deserialized.acquireRead().getDenseBlockValues(), 0);
+		Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), mo2Deserialized.acquireRead().getDenseBlockValues(), 0);
+		Assert.assertEquals(io.getLongValue(), ioDeserialized.getLongValue());
 	}
 
 	@Test
-	public void serializeNamedListObject() {
+	public void serializeListObjectProgramConverter() {
 		MatrixObject mo1 = generateDummyMatrix(10);
 		MatrixObject mo2 = generateDummyMatrix(20);
 		IntObject io = new IntObject(30);
-		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io), Arrays.asList("e1", "e2", "e3"));
+		ListObject lo;
+		if (_named == 1)
+			lo = new ListObject(Arrays.asList(mo1, mo2, io), Arrays.asList("e1", "e2", "e3"));
+		else
+			lo = new ListObject(Arrays.asList(mo1, mo2, io));
 
 		String serial = ProgramConverter.serializeDataObject("key", lo);
 		Object[] obj = ProgramConverter.parseDataObject(serial);
@@ -62,7 +117,10 @@ public class SerializationTest {
 		MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
 		MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
 		IntObject actualIO = (IntObject) actualLO.slice(2);
-		Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
+		if (_named == 1)
+			Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
 		Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), actualMO1.acquireRead().getDenseBlockValues(), 0);
 		Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), actualMO2.acquireRead().getDenseBlockValues(), 0);
 		Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
new file mode 100644
index 0000000..194df09
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedParamservTest extends AutomatedTestBase {
+    private final static String TEST_DIR = "functions/federated/paramserv/";
+    private final static String TEST_NAME = "FederatedParamservTest";
+    private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
+    private final static int _blocksize = 1024;
+
+    private final String _networkType;
+    private final int _numFederatedWorkers;
+    private final int _examplesPerWorker;
+    private final int _epochs;
+    private final int _batch_size;
+    private final double _eta;
+    private final String _utype;
+    private final String _freq;
+
+    private Types.ExecMode _platformOld;
+
+    // parameters
+    @Parameterized.Parameters
+    public static Collection<Object[]> parameters() {
+        return Arrays.asList(new Object[][] {
+                //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency
+                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}
+        });
+    }
+
+    public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) {
+        _networkType = networkType;
+        _numFederatedWorkers = numFederatedWorkers;
+        _examplesPerWorker = examplesPerWorker;
+        _batch_size = batch_size;
+        _epochs = epochs;
+        _eta = eta;
+        _utype = utype;
+        _freq = freq;
+    }
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+
+        _platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Override
+    public void tearDown() {
+        rtplatform = _platformOld;
+    }
+
+    @Test
+    public void federatedParamserv() {
+        // config
+        getAndLoadTestConfiguration(TEST_NAME);
+        String HOME = SCRIPT_DIR + TEST_DIR;
+        setOutputBuffering(true);
+
+        int C = 1, Hin = 28, Win = 28;
+        int numFeatures = C*Hin*Win;
+        int numLabels = 10;
+
+        // dml name
+        fullDMLScriptName = HOME + TEST_NAME + ".dml";
+        // generate program args
+        List<String> programArgsList = new ArrayList<>(Arrays.asList(
+                "-stats",
+                "-nvargs",
+                "examples_per_worker=" + _examplesPerWorker,
+                "num_features=" + numFeatures,
+                "num_labels=" + numLabels,
+                "epochs=" + _epochs,
+                "batch_size=" + _batch_size,
+                "eta=" + _eta,
+                "utype=" + _utype,
+                "freq=" + _freq,
+                "network_type=" + _networkType,
+                "channels=" + C,
+                "hin=" + Hin,
+                "win=" + Win
+        ));
+
+        // for each worker
+        List<Integer> ports = new ArrayList<>();
+        List<Thread> threads = new ArrayList<>();
+        for(int i = 0; i < _numFederatedWorkers; i++) {
+            // write row partitioned features to disk
+            writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
+                    new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures));
+            // write row partitioned labels to disk
+            writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
+                    new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels));
+
+            // start worker
+            ports.add(getRandomAvailablePort());
+            threads.add(startLocalFedWorkerThread(ports.get(i)));
+
+            // add worker to program args
+            programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i)));
+            programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i)));
+        }
+
+        programArgs = programArgsList.toArray(new String[0]);
+        runTest(null);
+
+        // cleanup
+        for(int i = 0; i < _numFederatedWorkers; i++) {
+            TestUtils.shutdownThreads(threads.get(i));
+        }
+    }
+
+    /**
+     * Generates a feature matrix that has the same shape as the MNIST dataset,
+     * but is completely random, with values in [0,1] (i.e. already "normalized")
+     *
+     *  @param numExamples Number of examples to generate
+     *  @param C Channels in the input data
+     *  @param Hin Height in Pixels of the input data
+     *  @param Win Width in Pixels of the input data
+     *  @return a dummy MNIST feature matrix
+     */
+    private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+        // Seed -1 takes the time in milliseconds as a seed
+        // Sparsity 1 means no sparsity
+        return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+    }
+
+    /**
+     * Generates a label matrix with the same shape as one-hot encoded MNIST labels,
+     * but filled with random values in [0,1] rather than true one-hot rows
+     *
+     *  @param numExamples Number of examples to generate
+     *  @param numLabels Number of labels to generate
+     *  @return a dummy MNIST label matrix
+     */
+    private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+        // Seed -1 takes the time in milliseconds as a seed
+        // Sparsity 1 means no sparsity
+        return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+    }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
new file mode 100644
index 0000000..55d05dc
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -0,0 +1,474 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all functions needed to train and evaluate a convolutional neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture, single-threaded, in the conventional way.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs, int batch_size, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  learning_rate = learning_rate  # learning rate
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+  print_interval = max(floor(iters / 25), 1)  # ensure a positive interval (avoids modulo by zero for small inputs)
+
+  print("[+] Starting optimization")
+  print("[+]  Learning rate: " + learning_rate)
+  print("[+]  Batch size: " + batch_size)
+  print("[+]  Iterations per epoch: " + iters + "\n")
+
+  for (e in 1:epochs) {
+    print("[+] Starting epoch: " + e)
+    print("|")
+    for(i in 1:iters) {
+      # Create the model list
+      model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
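+      # e.g., N=100, batch_size=32 yields the row ranges [1,32], [33,64], [65,96], [97,100]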
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+      model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+      W1 = as.matrix(model_updated[1])
+      W2 = as.matrix(model_updated[2])
+      W3 = as.matrix(model_updated[3])
+      W4 = as.matrix(model_updated[4])
+      b1 = as.matrix(model_updated[5])
+      b2 = as.matrix(model_updated[6])
+      b3 = as.matrix(model_updated[7])
+      b4 = as.matrix(model_updated[8])
+      vW1 = as.matrix(model_updated[9])
+      vW2 = as.matrix(model_updated[10])
+      vW3 = as.matrix(model_updated[11])
+      vW4 = as.matrix(model_updated[12])
+      vb1 = as.matrix(model_updated[13])
+      vb2 = as.matrix(model_updated[14])
+      vb3 = as.matrix(model_updated[15])
+      vb4 = as.matrix(model_updated[16])
+      if((i %% print_interval) == 0) {
+        print("█")
+      }
+    }
+    print("|")
+  }
+
+  model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
+
+/*
+ * Trains a convolutional net of the "LeNet" architecture using a parameter server with the specified properties.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - Y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - Y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *  - workers: Number of workers to create
+ *  - utype: update type of the parameter server (e.g. BSP, ASP)
+ *  - freq: update frequency (e.g. BATCH, EPOCH)
+ *  - scheme: data partitioning scheme
+ *  - mode: local or distributed
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  learning_rate = learning_rate  # learning rate
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+  # Create the model list
+  model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+  # Create the hyper parameter list
+  params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
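+  # Note: since 'features' and 'labels' are federated matrices here, the paramserv builtin
+  # dispatches to the federated backend and the 'mode' argument is effectively ignored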
+  model_trained = paramserv(model=model_list, features=X, labels=y,
+    val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+    mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - batch_size: Batch size
+ *  - model: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+  N = nrow(X)
+
+  # Network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  iters = ceil(N / batch_size)
+  parfor(i in 1:iters, check=0) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+/*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples.  The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ *  - y: Target matrix, of shape (N, K)
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1)
+ *  - accuracy: Scalar accuracy, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+# The gradients function must use the argument names 'model', 'hyperparams',
+# 'features' (batch features) and 'labels' (batch labels),
+# and must return the gradients as a list
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute loss & accuracy for training data
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  ## loss
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# The aggregation function must use the argument names 'model', 'gradients' and 'hyperparams',
+# and must always return the updated model as a list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+   W1 = as.matrix(model[1])
+   W2 = as.matrix(model[2])
+   W3 = as.matrix(model[3])
+   W4 = as.matrix(model[4])
+   b1 = as.matrix(model[5])
+   b2 = as.matrix(model[6])
+   b3 = as.matrix(model[7])
+   b4 = as.matrix(model[8])
+   dW1 = as.matrix(gradients[1])
+   dW2 = as.matrix(gradients[2])
+   dW3 = as.matrix(gradients[3])
+   dW4 = as.matrix(gradients[4])
+   db1 = as.matrix(gradients[5])
+   db2 = as.matrix(gradients[6])
+   db3 = as.matrix(gradients[7])
+   db4 = as.matrix(gradients[8])
+   vW1 = as.matrix(model[9])
+   vW2 = as.matrix(model[10])
+   vW3 = as.matrix(model[11])
+   vW4 = as.matrix(model[12])
+   vb1 = as.matrix(model[13])
+   vb2 = as.matrix(model[14])
+   vb3 = as.matrix(model[15])
+   vb4 = as.matrix(model[16])
+   learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+   mu = as.double(as.scalar(hyperparams["mu"]))
+
+   # Optimize with SGD w/ Nesterov momentum
+   [W1, vW1] = sgd_nesterov::update(W1, dW1, learning_rate, mu, vW1)
+   [b1, vb1] = sgd_nesterov::update(b1, db1, learning_rate, mu, vb1)
+   [W2, vW2] = sgd_nesterov::update(W2, dW2, learning_rate, mu, vW2)
+   [b2, vb2] = sgd_nesterov::update(b2, db2, learning_rate, mu, vb2)
+   [W3, vW3] = sgd_nesterov::update(W3, dW3, learning_rate, mu, vW3)
+   [b3, vb3] = sgd_nesterov::update(b3, db3, learning_rate, mu, vb3)
+   [W4, vW4] = sgd_nesterov::update(W4, dW4, learning_rate, mu, vW4)
+   [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
+
+   model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
new file mode 100644
index 0000000..16c72c4
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+
+# create federated input matrices
+features = federated(addresses=list($X0, $X1),
+    ranges=list(list(0, 0), list($examples_per_worker, $num_features),
+                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_features)))
+
+labels = federated(addresses=list($y0, $y1),
+    ranges=list(list(0, 0), list($examples_per_worker, $num_labels),
+                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels)))
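+# the ranges give the upper-left and lower-right (row, col) indexes of each worker's slice,
+# i.e. features and labels are row-partitioned with examples_per_worker rows per federated worker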
+
+epochs = $epochs
+batch_size = $batch_size
+learning_rate = $eta
+utype = $utype
+freq = $freq
+network_type = $network_type
+
+# currently ignored parameters
+workers = 1
+scheme = "DISJOINT_CONTIGUOUS"
+paramserv_mode = "LOCAL"
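+# (in federated mode the number of workers and the data partitioning are derived from the
+# federated ranges of the input matrices, so these values are not consulted)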
+
+# config for the cnn
+channels = $channels
+hin = $hin
+win = $win
+
+if(network_type == "TwoNN") {
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+else {
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
new file mode 100644
index 0000000..3bcfe84
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -0,0 +1,299 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all functions needed to train and evaluate a simple feed-forward neural network
+ * on different execution schemes and with different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed-forward neural network with two hidden layers, single-threaded, in the conventional way.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D)
+ *  - y_val: Target validation matrix of shape (N_val, K)
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int batch_size, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200)
+  [W2, b2] = affine::init(200, 200)
+  [W3, b3] = affine::init(200, K)
+  W3 = W3 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=learning_rate)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+  print_interval = max(floor(iters / 25), 1)  # ensure a positive interval (avoids modulo by zero for small inputs)
+
+  print("[+] Starting optimization")
+  print("[+]  Learning rate: " + learning_rate)
+  print("[+]  Batch size: " + batch_size)
+  print("[+]  Iterations per epoch: " + iters + "\n")
+
+  for (e in 1:epochs) {
+    print("[+] Starting epoch: " + e)
+    print("|")
+    for(i in 1:iters) {
+      # Create the model list
+      model_list = list(W1, W2, W3, b1, b2, b3)
+
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+      model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+      W1 = as.matrix(model_updated[1])
+      W2 = as.matrix(model_updated[2])
+      W3 = as.matrix(model_updated[3])
+      b1 = as.matrix(model_updated[4])
+      b2 = as.matrix(model_updated[5])
+      b3 = as.matrix(model_updated[6])
+
+      if((i %% print_interval) == 0) {
+        print("█")
+      }
+    }
+    print("|")
+  }
+
+  model_trained = list(W1, W2, W3, b1, b2, b3)
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using a parameter server with specified properties.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D)
+ *  - y_val: Target validation matrix of shape (N_val, K)
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *  - workers: Number of workers to create
+ *  - utype: update type of the parameter server (e.g. BSP, ASP)
+ *  - freq: update frequency (e.g. BATCH, EPOCH)
+ *  - scheme: data partitioning scheme
+ *  - mode: local or distributed
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int workers,
+                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200)
+  [W2, b2] = affine::init(200, 200)
+  [W3, b3] = affine::init(200, K)
+
+  # Create the model list
+  model_list = list(W1, W2, W3, b1, b2, b3)
+  # Create the hyper parameter list
+  params = list(learning_rate=learning_rate)
+  # Use paramserv function
+  model_trained = paramserv(model=model_list, features=X, labels=y,
+    val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+    mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Computes the class probability predictions of a simple feed forward neural network.
+ *
+ * Inputs:
+ *  - X: The input data matrix of shape (N, D)
+ *  - model: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X,
+                   list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  out1relu = relu::forward(affine::forward(X, W1, b1))
+  out2relu = relu::forward(affine::forward(out1relu, W2, b2))
+  probs = softmax::forward(affine::forward(out2relu, W3, b3))
+}
+
+/*
+ * Evaluates a simple feed forward neural network.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples.  The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K).
+ *  - y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+# The gradients function must use the argument names 'model', 'hyperparams',
+# 'features' (batch features) and 'labels' (batch labels),
+# and must return the gradients as a list
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+    return (list[unknown] gradients) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  # Compute forward pass
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  out1 = affine::forward(features, W1, b1)
+  out1relu = relu::forward(out1)
+  out2 = affine::forward(out1relu, W2, b2)
+  out2relu = relu::forward(out2)
+  out3 = affine::forward(out2relu, W3, b3)
+  probs = softmax::forward(out3)
+
+  # Compute loss & accuracy for training data
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  dout3 = softmax::backward(dprobs, out3)
+  [dout2relu, dW3, db3] = affine::backward(dout3, out2relu, W3, b3)
+  dout2 = relu::backward(dout2relu, out2)
+  [dout1relu, dW2, db2] = affine::backward(dout2, out1relu, W2, b2)
+  dout1 = relu::backward(dout1relu, out1)
+  [dfeatures, dW1, db1] = affine::backward(dout1, features, W1, b1)
+
+  gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
+
+# The aggregation function must use the argument names 'model', 'gradients' and 'hyperparams',
+# and must always return the updated model as a list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+  dW1 = as.matrix(gradients[1])
+  dW2 = as.matrix(gradients[2])
+  dW3 = as.matrix(gradients[3])
+  db1 = as.matrix(gradients[4])
+  db2 = as.matrix(gradients[5])
+  db3 = as.matrix(gradients[6])
+  learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+
+  # Optimize with SGD
+  W3 = sgd::update(W3, dW3, learning_rate)
+  b3 = sgd::update(b3, db3, learning_rate)
+  W2 = sgd::update(W2, dW2, learning_rate)
+  b2 = sgd::update(b2, db2, learning_rate)
+  W1 = sgd::update(W1, dW1, learning_rate)
+  b1 = sgd::update(b1, db1, learning_rate)
+
+  model_result = list(W1, W2, W3, b1, b2, b3)
+}
\ No newline at end of file