You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/05/29 06:21:24 UTC

[2/2] systemml git commit: [SYSTEMML-2085] Initial version of local backend for paramserv builtin

[SYSTEMML-2085] Initial version of local backend for paramserv builtin

Closes #771.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97018d4e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97018d4e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97018d4e

Branch: refs/heads/master
Commit: 97018d4e688ba7eeaaa4567ca1e174a3c5525468
Parents: c7a9e01
Author: EdgarLGB <gu...@atos.net>
Authored: Mon May 28 23:17:18 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon May 28 23:17:20 2018 -0700

----------------------------------------------------------------------
 .../ParameterizedBuiltinFunctionExpression.java |   6 +-
 .../java/org/apache/sysml/parser/Statement.java |   1 +
 .../context/ExecutionContext.java               |  13 +-
 .../controlprogram/paramserv/LocalPSWorker.java |  97 +++++
 .../paramserv/LocalParamServer.java             |  59 +++
 .../controlprogram/paramserv/PSWorker.java      | 131 +++++++
 .../controlprogram/paramserv/ParamServer.java   | 232 +++++++++++
 .../paramserv/ParamservUtils.java               |  97 +++++
 .../runtime/instructions/cp/CPOperand.java      |   2 +-
 .../runtime/instructions/cp/ListObject.java     |  14 +
 .../cp/MatrixIndexingCPInstruction.java         |   4 +-
 .../cp/ParamservBuiltinCPInstruction.java       | 257 ++++++++++++-
 .../test/integration/AutomatedTestBase.java     |  18 +-
 .../functions/paramserv/ParamservFuncTest.java  |  29 +-
 .../paramserv/mnist_lenet_paramserv.dml         | 383 +++++++++++++++++++
 .../mnist_lenet_paramserv_minimum_version.dml   | 377 ++++++++++++++++++
 .../functions/paramserv/paramserv-all-args.dml  |   4 +-
 .../functions/paramserv/paramserv-ipa-test.dml  |  47 ---
 .../paramserv/paramserv-minimum-version.dml     |  52 +++
 .../functions/paramserv/paramserv-miss-args.dml |   4 +-
 .../functions/paramserv/paramserv-nn-test.dml   |  52 +++
 .../paramserv-without-optional-args.dml         |   4 +-
 22 files changed, 1805 insertions(+), 78 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 3d74f8d..99aec78 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -341,12 +341,12 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 				.collect(Collectors.toSet());
 		checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional);
 		Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet());
-		checkStringParam(false, fname, Statement.PS_FREQUENCY, frequencies, conditional);
+		checkStringParam(true, fname, Statement.PS_FREQUENCY, frequencies, conditional);
 		checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional);
 		checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional);
-		checkDataValueType(false, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
+		checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
 		Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet());
-		checkStringParam(false, fname, Statement.PS_SCHEME, schemes, conditional);
+		checkStringParam(true, fname, Statement.PS_SCHEME, schemes, conditional);
 		checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
 		Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet());
 		checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional);

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/Statement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Statement.java b/src/main/java/org/apache/sysml/parser/Statement.java
index 4853a47..1987d31 100644
--- a/src/main/java/org/apache/sysml/parser/Statement.java
+++ b/src/main/java/org/apache/sysml/parser/Statement.java
@@ -71,6 +71,7 @@ public abstract class Statement implements ParseInfo
 	public static final String PS_UPDATE_FUN = "upd";
 	public static final String PS_AGGREGATION_FUN = "agg";
 	public static final String PS_MODE = "mode";
+	public static final String PS_GRADIENTS = "gradients";
 	public enum PSModeType {
 		LOCAL, REMOTE_SPARK
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 67b2a83..6807848 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -42,6 +42,7 @@ import org.apache.sysml.runtime.instructions.cp.CPInstruction;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
@@ -443,7 +444,17 @@ public class ExecutionContext {
 	public void setScalarOutput(String varName, ScalarObject so) {
 		setVariable(varName, so);
 	}
-	
+
+	public ListObject getListObject(String name) {
+		Data value = getVariable(name);
+		if( value instanceof ListObject ) //common case: existing list variable
+			return (ListObject) value;
+		//error handling: variable missing (null is no ListObject) or of other type
+		if( value == null )
+			throw new DMLRuntimeException("Variable '" + name + "' does not exist in the symbol table.");
+		throw new DMLRuntimeException("Variable '" + name + "' is not a list.");
+	}
+
 	public void releaseMatrixOutputForGPUInstruction(String varName) {
 		MatrixObject mo = getMatrixObject(varName);
 		if(mo.getGPUObject(getGPUContext(0)) == null || !mo.getGPUObject(getGPUContext(0)).isAllocated()) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
new file mode 100644
index 0000000..181b866
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Worker of the local parameter server backend, executed as a {@link Runnable}.
+ * In every epoch it iterates over its feature/label partition in mini-batches:
+ * it pulls (a copy of) the global model from the parameter server, invokes the
+ * update function on the batch, and pushes the resulting gradients back.
+ */
+public class LocalPSWorker extends PSWorker implements Runnable {
+
+	protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+
+	public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
+			ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
+		super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps);
+	}
+
+	@Override
+	public void run() {
+
+		long dataSize = _features.getNumRows();
+		// Number of mini-batches per epoch, rounded up so that a trailing
+		// partial batch is processed as well (integer ceil division).
+		int totalIter = (int) ((dataSize + _batchSize - 1) / _batchSize);
+
+		for (int i = 0; i < _epochs; i++) {
+			for (int j = 0; j < totalIter; j++) {
+				// Pull the global parameters from ps. Work on a copy, because the
+				// model used in this iteration is cleaned up at its end.
+				ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID));
+				if (LOG.isDebugEnabled()) {
+					LOG.debug(String.format(
+							"Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID,
+							globalParams.getDataSize() / 1024));
+				}
+				_ec.setVariable(Statement.PS_MODEL, globalParams);
+
+				// 1-based, inclusive row range [begin, end] of the current batch
+				long begin = j * _batchSize + 1;
+				long end = Math.min((j + 1) * _batchSize, dataSize);
+
+				// Get batch features and labels
+				MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
+				MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
+				_ec.setVariable(Statement.PS_FEATURES, bFeatures);
+				_ec.setVariable(Statement.PS_LABELS, bLabels);
+
+				if (LOG.isDebugEnabled()) {
+					LOG.debug(String.format(
+							"Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]",
+							_workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1,
+							_epochs, j + 1, totalIter));
+				}
+
+				// Invoke the update function
+				_inst.processInstruction(_ec);
+
+				// Get the gradients, i.e., the single list output of the update function
+				ListObject gradients = (ListObject) _ec.getVariable(_outputs.get(0).getName());
+
+				// Push the gradients to ps
+				_ps.push(_workerID, gradients);
+				if (LOG.isDebugEnabled()) {
+					LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.",
+							_workerID, gradients.getDataSize() / 1024));
+				}
+
+				// Release the model copy and the batch slices of this iteration
+				ParamservUtils.cleanupListObject(_ec, globalParams);
+				ParamservUtils.cleanupData(bFeatures);
+				ParamservUtils.cleanupData(bLabels);
+			}
+			if (LOG.isDebugEnabled()) {
+				LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+			}
+		}
+		if (LOG.isDebugEnabled()) {
+			LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
new file mode 100644
index 0000000..d060a91
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Parameter server of the local backend. Pushes and pulls synchronize on the
+ * monitor {@code _lock} shared with the aggregation service of {@link ParamServer}.
+ */
+public class LocalParamServer extends ParamServer {
+
+	public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
+			Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum,
+			ListObject hyperParams) {
+		super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+	}
+
+	/**
+	 * Enqueue the gradients of the given worker and wake up all threads
+	 * waiting on the lock (e.g., the aggregation service).
+	 */
+	@Override
+	public void push(long workerID, ListObject gradients) {
+		synchronized (_lock) {
+			_queue.add(new Gradient(workerID, gradients));
+			_lock.notifyAll();
+		}
+	}
+
+	/**
+	 * Get the current global model. If the given worker already pulled the
+	 * current model, block until its pulled state is reset (which the
+	 * aggregation service does after updating the model).
+	 */
+	@Override
+	public Data pull(long workerID) {
+		synchronized (_lock) {
+			while (getPulledState((int) workerID)) {
+				try {
+					_lock.wait();
+				} catch (InterruptedException e) {
+					// restore the interrupt status before propagating the failure
+					Thread.currentThread().interrupt();
+					throw new DMLRuntimeException(
+							String.format("Local worker_%d: failed to pull the global parameters.", workerID), e);
+				}
+			}
+			setPulledState((int) workerID, true);
+			// read the model while holding the lock, because the aggregation
+			// service replaces it under the same lock
+			return getResult();
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
new file mode 100644
index 0000000..9ace823
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Base class of the parameter server workers. It keeps the worker's training
+ * configuration and data, binds the user-defined update function into a
+ * function call instruction, and validates the signature of that function.
+ */
+@SuppressWarnings("unused")
+public abstract class PSWorker {
+
+	long _workerID = -1;
+	int _epochs;
+	long _batchSize;
+	MatrixObject _features;
+	MatrixObject _labels;
+	ExecutionContext _ec;
+	ParamServer _ps;
+	private String _updFunc;
+	private Statement.PSFrequency _freq;
+	private MatrixObject _valFeatures;
+	private MatrixObject _valLabels;
+
+	ArrayList<DataIdentifier> _outputs;
+	FunctionCallCPInstruction _inst;
+
+	public PSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
+			ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
+		_workerID = workerID;
+		_updFunc = updFunc;
+		_freq = freq;
+		_epochs = epochs;
+		_batchSize = batchSize;
+
+		// separate execution context over the same program, with the hyper
+		// parameters (if any) pre-bound as a variable
+		_ec = ExecutionContextFactory.createContext(ec.getProgram());
+		if (hyperParams != null)
+			_ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
+		_ps = ps;
+
+		// resolve the (optionally namespace-qualified) update function
+		String[] key = DMLProgram.splitFunctionKey(updFunc);
+		boolean qualified = key.length == 2;
+		String funcNS = qualified ? key[0] : null;
+		String funcName = qualified ? key[1] : key[0];
+		FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+		ArrayList<DataIdentifier> inputs = fpb.getInputParams();
+		_outputs = fpb.getOutputParams();
+
+		// call instruction binding the formal parameters by name
+		_inst = new FunctionCallCPInstruction(funcNS, funcName, toOperands(inputs),
+				toNames(inputs), toNames(_outputs), "update function");
+
+		// the update function has to accept features, labels and the model
+		// (plus the hyper parameters, if given) ...
+		checkInput(inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES);
+		checkInput(inputs, Expression.DataType.MATRIX, Statement.PS_LABELS);
+		checkInput(inputs, Expression.DataType.LIST, Statement.PS_MODEL);
+		if (hyperParams != null)
+			checkInput(inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS);
+
+		// ... and has to return exactly one list, i.e., the gradients
+		if (_outputs.size() != 1)
+			throw new DMLRuntimeException(
+				String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc));
+		if (_outputs.get(0).getDataType() != Expression.DataType.LIST)
+			throw new DMLRuntimeException(
+					String.format("The output of the '%s' function should be type of list.", updFunc));
+	}
+
+	/** Create one (non-literal) operand per formal parameter, in order. */
+	private static CPOperand[] toOperands(ArrayList<DataIdentifier> params) {
+		CPOperand[] ops = new CPOperand[params.size()];
+		for (int i = 0; i < ops.length; i++) {
+			DataIdentifier p = params.get(i);
+			ops[i] = new CPOperand(p.getName(), p.getValueType(), p.getDataType());
+		}
+		return ops;
+	}
+
+	/** Collect the names of the given parameters, in order. */
+	private static ArrayList<String> toNames(ArrayList<DataIdentifier> params) {
+		return params.stream().map(DataIdentifier::getName)
+				.collect(Collectors.toCollection(ArrayList::new));
+	}
+
+	private void checkInput(ArrayList<DataIdentifier> inputs, Expression.DataType dt, String pname) {
+		// exactly one input with the expected name and data type is required
+		long matches = inputs.stream()
+				.filter(in -> in.getDataType() == dt && pname.equals(in.getName()))
+				.count();
+		if (matches != 1) {
+			throw new DMLRuntimeException(
+				String.format("The '%s' function should provide an input of '%s' type named '%s'.", _updFunc, dt, pname));
+		}
+	}
+
+	public void setFeatures(MatrixObject features) {
+		_features = features;
+	}
+
+	public void setLabels(MatrixObject labels) {
+		_labels = labels;
+	}
+
+	public void setValFeatures(MatrixObject valFeatures) {
+		_valFeatures = valFeatures;
+	}
+
+	public void setValLabels(MatrixObject valLabels) {
+		_valLabels = valLabels;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
new file mode 100644
index 0000000..6e1cd13
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Base class of the parameter server. It holds the global model, a queue of
+ * gradients pushed by the workers, and an aggregation service that runs in its
+ * own thread and applies the aggregation function to the queued gradients.
+ * Workers and the aggregation service synchronize on the monitor {@code _lock}.
+ */
+public abstract class ParamServer {
+
+	/** Gradients pushed by one worker. */
+	public class Gradient {
+		final long _workerID;
+		final ListObject _gradients;
+
+		public Gradient(long workerID, ListObject gradients) {
+			this._workerID = workerID;
+			this._gradients = gradients;
+		}
+	}
+
+	Queue<Gradient> _queue;
+	final Object _lock = new Object();
+	// volatile: replaced by the aggregation thread, read by other threads via getResult()
+	private volatile ListObject _model;
+	private AggregationService _aggService;
+	private Thread _aggThread;
+	private boolean[] _pulledStates;
+
+	protected ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
+			Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, ListObject hyperParams) {
+		this._queue = new ConcurrentLinkedQueue<>();
+		this._model = model;
+		this._aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum, hyperParams);
+		this._pulledStates = new boolean[workerNum];
+		this._aggThread = new Thread(_aggService);
+	}
+
+	public abstract void push(long workerID, ListObject value);
+
+	public abstract Data pull(long workerID);
+
+	/** Start the aggregation service thread. */
+	public void start() {
+		_aggService._alive = true;
+		_aggThread.start();
+	}
+
+	/**
+	 * Stop the aggregation service and wait until its thread terminates.
+	 * Gradients that are already queued are still aggregated before it exits.
+	 */
+	public void stop() {
+		synchronized (_lock) {
+			_aggService._alive = false;
+			// Wake up the aggregation service if it is waiting for gradients;
+			// otherwise it never observes the stop request and join() blocks forever.
+			_lock.notifyAll();
+		}
+		try {
+			_aggThread.join();
+		} catch (InterruptedException e) {
+			Thread.currentThread().interrupt();
+			throw new DMLRuntimeException("Parameter server: failed when stopping the server.", e);
+		}
+	}
+
+	public ListObject getResult() {
+		return _model;
+	}
+
+	public boolean getPulledState(int workerID) {
+		return _pulledStates[workerID];
+	}
+
+	public void setPulledState(int workerID, boolean state) {
+		_pulledStates[workerID] = state;
+	}
+
+	private void resetPulledStates() {
+		Arrays.fill(_pulledStates, false);
+	}
+
+	/**
+	 * Inner aggregation service which updates the model. It invokes the
+	 * aggregation function for each queued gradient and, once the gradients of
+	 * all workers are aggregated, resets the pulled states and notifies the workers.
+	 */
+	@SuppressWarnings("unused")
+	private class AggregationService implements Runnable {
+
+		protected final Log LOG = LogFactory.getLog(AggregationService.class.getName());
+
+		protected ExecutionContext _ec;
+		private Statement.PSFrequency _freq;
+		private Statement.PSUpdateType _updateType;
+		private FunctionCallCPInstruction _inst;
+		private DataIdentifier _output;
+		private boolean _alive;  // set by start(), cleared under _lock by stop()
+		private boolean[] _finishedStates;  // Workers' finished states
+
+		AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType,
+				ExecutionContext ec, int workerNum, ListObject hyperParams) {
+			_ec = ExecutionContextFactory.createContext(ec.getProgram());
+			_freq = freq;
+			_updateType = updateType;
+			if (hyperParams != null) {
+				_ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
+			}
+			_finishedStates = new boolean[workerNum];
+
+			// Fetch the aggregation function
+			String[] keys = DMLProgram.splitFunctionKey(aggFunc);
+			String funcName = keys[0];
+			String funcNS = null;
+			if (keys.length == 2) {
+				funcNS = keys[0];
+				funcName = keys[1];
+			}
+			FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+			ArrayList<DataIdentifier> inputs = func.getInputParams();
+			ArrayList<DataIdentifier> outputs = func.getOutputParams();
+
+			// Check the output of the aggregation function
+			if (outputs.size() != 1) {
+				throw new DMLRuntimeException(String.format(
+						"The output of the '%s' function should provide one list containing the updated model.",
+						aggFunc));
+			}
+			if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
+				throw new DMLRuntimeException(
+						String.format("The output of the '%s' function should be type of list.", aggFunc));
+			}
+			_output = outputs.get(0);
+
+			CPOperand[] boundInputs = inputs.stream()
+					.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+					.toArray(CPOperand[]::new);
+			ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+			ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+					.collect(Collectors.toCollection(ArrayList::new));
+			_inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames,
+					"aggregate function");
+		}
+
+		boolean isAlive() {
+			return _alive;
+		}
+
+		private boolean allFinished() {
+			return !ArrayUtils.contains(_finishedStates, false);
+		}
+
+		private void resetFinishedStates() {
+			Arrays.fill(_finishedStates, false);
+		}
+
+		private void setFinishedState(int workerID) {
+			_finishedStates[workerID] = true;
+		}
+
+		@Override
+		public void run() {
+			synchronized (_lock) {
+				// Each round lasts until the gradients of every worker are aggregated. The
+				// service returns only once it is stopped and the gradient queue is drained.
+				while (true) {
+					do {
+						while (_queue.isEmpty()) {
+							if (!isAlive()) {
+								return;
+							}
+							try {
+								_lock.wait();
+							} catch (InterruptedException e) {
+								Thread.currentThread().interrupt();
+								throw new DMLRuntimeException(
+										"Aggregation service: error when waiting for the coming gradients.", e);
+							}
+						}
+						Gradient p = _queue.remove();
+						if (LOG.isDebugEnabled()) {
+							LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.",
+									p._gradients.getDataSize() / 1024, p._workerID));
+						}
+
+						setFinishedState((int) p._workerID);
+
+						// Populate the variables table with the gradients and model
+						_ec.setVariable(Statement.PS_GRADIENTS, p._gradients);
+						_ec.setVariable(Statement.PS_MODEL, _model);
+
+						// Invoke the aggregate function
+						_inst.processInstruction(_ec);
+
+						// Get the output
+						ListObject newModel = (ListObject) _ec.getVariable(_output.getName());
+
+						// Update the model with the new output
+						ParamservUtils.cleanupListObject(_ec, _model);
+						ParamservUtils.cleanupListObject(_ec, p._gradients);
+						_model = newModel;
+
+					} while (!allFinished());
+
+					// notify all the workers to get the updated model
+					resetPulledStates();
+					resetFinishedStates();
+					_lock.notifyAll();
+					if (LOG.isDebugEnabled()) {
+						LOG.debug("Global parameter is broadcasted successfully.");
+					}
+				}
+			}
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
new file mode 100644
index 0000000..54c5d6c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MetaDataFormat;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+
+/**
+ * Helper functions of the parameter server backends.
+ */
+public class ParamservUtils {
+
+	/**
+	 * Deep copy the list object: matrix entries are copied, all other entries
+	 * (e.g., scalars) are reused as is. Named and unnamed lists are supported;
+	 * an empty list is returned unchanged.
+	 *
+	 * @param lo list object
+	 * @return a new copied list object
+	 * @throws DMLRuntimeException if the list contains a list or a frame
+	 */
+	public static ListObject copyList(ListObject lo) {
+		if (lo.getLength() == 0) {
+			return lo;
+		}
+		// Iterate over the entries rather than looking them up by name, so that
+		// lists without names work as well; the names (if any) are carried over.
+		List<Data> newData = lo.getData().stream().map(oldData -> {
+			if (oldData instanceof MatrixObject) {
+				MatrixObject mo = (MatrixObject) oldData;
+				return sliceMatrix(mo, 1, mo.getNumRows());
+			} else if (oldData instanceof ListObject || oldData instanceof FrameObject) {
+				throw new DMLRuntimeException("Copy list: does not support list or frame.");
+			} else {
+				return oldData;
+			}
+		}).collect(Collectors.toList());
+		return new ListObject(newData, lo.getNames());
+	}
+
+	/**
+	 * Remove the variables named like the list entries from the symbol table of
+	 * the given execution context, and clean up the cacheable list entries.
+	 *
+	 * @param ec execution context
+	 * @param lo list object
+	 */
+	public static void cleanupListObject(ExecutionContext ec, ListObject lo) {
+		// an unnamed list (null names) has no variable names to remove
+		if (lo.getNames() != null) {
+			ec.getVariables().removeAllIn(new HashSet<>(lo.getNames()));
+		}
+		lo.getData().forEach(ParamservUtils::cleanupData);
+	}
+
+	/**
+	 * Enable cleanup of the given data and clear it, if it is cacheable
+	 * (e.g., a matrix); any other data is left untouched.
+	 *
+	 * @param data data object
+	 */
+	public static void cleanupData(Data data) {
+		if( !(data instanceof CacheableData) )
+			return;
+		CacheableData<?> cd = (CacheableData<?>) data;
+		cd.enableCleanup(true);
+		cd.clearData();
+	}
+
+	/**
+	 * Slice the row range [rl, rh] of the matrix into a new matrix object.
+	 *
+	 * @param mo input matrix
+	 * @param rl low boundary (1-based, inclusive)
+	 * @param rh high boundary (1-based, inclusive)
+	 * @return new sliced matrix
+	 */
+	public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
+		MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, null,
+			new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1),
+				OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+		MatrixBlock tmp = mo.acquireRead();
+		try {
+			result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0,
+				tmp.getNumColumns()-1, new MatrixBlock()));
+		}
+		finally {
+			// release the input even if slicing fails, so it is not left acquired
+			mo.release();
+		}
+		result.release();
+		return result;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
index 1ca8eab..22b79b0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
@@ -46,7 +46,7 @@ public class CPOperand
 		this(name, vt, dt, false);
 	}
 
-	private CPOperand(String name, ValueType vt, DataType dt, boolean literal) {
+	public CPOperand(String name, ValueType vt, DataType dt, boolean literal) {
 		_name = name;
 		_valueType = vt;
 		_dataType = dt;

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
index 95f03b5..670190c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
@@ -25,6 +25,7 @@ import java.util.List;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 
 public class ListObject extends Data {
 	private static final long serialVersionUID = 3652422061598967358L;
@@ -107,6 +108,19 @@ public class ListObject extends Data {
 		return (_names == null) ? null : _names.get(ix);
 	}
 
+	/**
+	 * @return true if the entries of this list have names
+	 */
+	public boolean isNamedList() {
+		return _names != null;
+	}
+
+	/**
+	 * Note: returns the internal list (not a copy), so modifications are
+	 * visible through this list object.
+	 *
+	 * @return the list entries
+	 */
+	public List<Data> getData() {
+		return _data;
+	}
+
+	/**
+	 * Sum of {@code getDataSize()} over all cacheable entries of this list;
+	 * non-cacheable entries (e.g., scalars) are ignored.
+	 *
+	 * @return summed data size, or 0 if the list has no cacheable entries
+	 */
+	public long getDataSize() {
+		// mapToLong(..).sum() is 0 for no matches, unlike reduce(..).get() which throws
+		return _data.stream()
+			.filter(data -> data instanceof CacheableData)
+			.mapToLong(data -> ((CacheableData<?>) data).getDataSize())
+			.sum();
+	}
+
 	@Override
 	public String getDebugName() {
 		return toString();

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index 51cc4c1..4e5d4c0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -34,8 +34,8 @@ import org.apache.sysml.utils.Statistics;
 
 public final class MatrixIndexingCPInstruction extends IndexingCPInstruction {
 
-	protected MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl,
-			CPOperand cu, CPOperand out, String opcode, String istr) {
+	// NOTE(review): visibility widened from protected to public in this change,
+	// presumably so the instruction can be constructed outside the CP instruction
+	// parser; confirm the intended callers.
+	public MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+			CPOperand out, String opcode, String istr) {
 		super(in, rl, ru, cl, cu, out, opcode, istr);
 	}
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index ddc56ae..3ab0fc8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,14 +19,62 @@
 
 package org.apache.sysml.runtime.instructions.cp;
 
+import static org.apache.sysml.parser.Statement.PSFrequency;
+import static org.apache.sysml.parser.Statement.PSModeType;
+import static org.apache.sysml.parser.Statement.PSScheme;
+import static org.apache.sysml.parser.Statement.PSUpdateType;
+import static org.apache.sysml.parser.Statement.PS_AGGREGATION_FUN;
+import static org.apache.sysml.parser.Statement.PS_BATCH_SIZE;
+import static org.apache.sysml.parser.Statement.PS_EPOCHS;
+import static org.apache.sysml.parser.Statement.PS_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_FREQUENCY;
+import static org.apache.sysml.parser.Statement.PS_HYPER_PARAMS;
+import static org.apache.sysml.parser.Statement.PS_LABELS;
+import static org.apache.sysml.parser.Statement.PS_MODE;
+import static org.apache.sysml.parser.Statement.PS_MODEL;
+import static org.apache.sysml.parser.Statement.PS_PARALLELISM;
+import static org.apache.sysml.parser.Statement.PS_SCHEME;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
+import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
+
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
-import org.apache.sysml.parser.Statement;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.NativeHelper;
 
 public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
 
+	// defaults for optional paramserv arguments
+	private static final int DEFAULT_BATCH_SIZE = 64;
+	private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH;
+	private static final int DEFAULT_LEVEL_PARALLELISM = InfrastructureAnalyzer.getLocalParallelism();
+	private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
+
+	//internal local debug level
+	private static final boolean LDEBUG = false;
+
+	static {
+		// for internal debugging only: raise the log level of the paramserv package
+		if (LDEBUG) {
+			Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
+		}
+	}
+
 	protected ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
 			String opcode, String istr) {
 		super(op, paramsMap, out, opcode, istr);
@@ -34,8 +82,209 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 	@Override
 	public void processInstruction(ExecutionContext ec) {
-		ListObject model = (ListObject) ec.getVariable(getParam(Statement.PS_MODEL));
-		ListObject outList = model.slice(0, model.getLength() - 1);
-		ec.setVariable(output.getName(), outList);
+
+		PSModeType mode = PSModeType.valueOf(getParam(PS_MODE));
+		int workerNum = getWorkerNum(mode);
+		String updFunc = getParam(PS_UPDATE_FUN);
+		String aggFunc = getParam(PS_AGGREGATION_FUN);
+		PSFrequency freq = getFrequency();
+		PSUpdateType updateType = getUpdateType();
+		int epochs = Integer.valueOf(getParam(PS_EPOCHS));
+		if (epochs <= 0) {
+			throw new DMLRuntimeException(
+					String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.",
+							PS_EPOCHS));
+		}
+		long batchSize = getBatchSize();
+
+		// Create the parameter server
+		ListObject model = ec.getListObject(getParam(PS_MODEL));
+		ListObject hyperParams = getHyperParams(ec);
+		ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, ec, hyperParams);
+
+		// Create the local workers
+		List<LocalPSWorker> workers = IntStream.range(0, workerNum)
+				.mapToObj(i -> new LocalPSWorker((long) i, updFunc, freq, epochs, batchSize, hyperParams, ec, ps))
+				.collect(Collectors.toList());
+
+		// Do data partition
+		doDataPartition(ec, workers);
+
+		// Create the worker threads
+		List<Thread> threads = workers.stream().map(Thread::new).collect(Collectors.toList());
+
+		// Start the ps
+		ps.start();
+
+		// Start the workers
+		threads.forEach(Thread::start);
+
+		// Wait for the workers stopping
+		threads.forEach(thread -> {
+			try {
+				thread.join();
+			} catch (InterruptedException e) {
+				throw new DMLRuntimeException("Paramserv function: Failed to join the worker threads.", e);
+			}
+		});
+
+		ps.stop();
+
+		// Create the output
+		ListObject result = ps.getResult();
+		ec.setVariable(output.getName(), result);
+	}
+
+	/**
+	 * Parses the update type argument; only BSP is supported so far.
+	 *
+	 * @return the requested update type
+	 */
+	private PSUpdateType getUpdateType() {
+		PSUpdateType utype = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
+		if (utype == PSUpdateType.ASP || utype == PSUpdateType.SSP) {
+			throw new DMLRuntimeException(String.format("Not support update type '%s'.", utype));
+		}
+		return utype;
+	}
+
+	/**
+	 * Parses the optional update frequency argument, falling back to the
+	 * default; per-epoch updates are not supported yet.
+	 *
+	 * @return the requested (or default) update frequency
+	 */
+	private PSFrequency getFrequency() {
+		if (getParameterMap().containsKey(PS_FREQUENCY)) {
+			PSFrequency requested = PSFrequency.valueOf(getParam(PS_FREQUENCY));
+			if (requested == PSFrequency.EPOCH) {
+				throw new DMLRuntimeException("Not support epoch update frequency.");
+			}
+			return requested;
+		}
+		return DEFAULT_UPDATE_FREQUENCY;
+	}
+
+	/**
+	 * Get the number of workers, i.e., the requested parallelism (or the local
+	 * parallelism by default) bounded by the available vcores.
+	 *
+	 * @param mode execution mode
+	 * @return number of workers (at least 1 in LOCAL mode)
+	 */
+	private int getWorkerNum(PSModeType mode) {
+		int workerNum = DEFAULT_LEVEL_PARALLELISM;
+		if (getParameterMap().containsKey(PS_PARALLELISM)) {
+			workerNum = Integer.parseInt(getParam(PS_PARALLELISM));
+			if (workerNum <= 0) {
+				throw new DMLRuntimeException(
+						String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.",
+								PS_PARALLELISM));
+			}
+		}
+		switch (mode) {
+		case LOCAL:
+			//FIXME: this is a workaround for a maximum number of buffers in openblas
+			//However, the root cause is a missing function preparation for each worker
+			//(i.e., deep copy with unique file names, and reduced degree of parallelism)
+			int vcores = InfrastructureAnalyzer.getLocalParallelism();
+			int maxWorkers = "openblas".equals(NativeHelper.getCurrentBLAS()) ? vcores / 2 : vcores;
+			// keep at least one worker, e.g., if vcores / 2 == 0 on a single vcore
+			workerNum = Math.max(1, Math.min(workerNum, maxWorkers));
+			break;
+		case REMOTE_SPARK:
+			throw new DMLRuntimeException("Do not support remote spark.");
+		}
+		return workerNum;
+	}
+
+	/**
+	 * Create the parameter server serving the workers of the given mode;
+	 * only the LOCAL mode is supported so far.
+	 *
+	 * @param mode execution mode
+	 * @param aggFunc name of the aggregation function
+	 * @param freq update frequency
+	 * @param updateType update type
+	 * @param workerNum number of workers
+	 * @param model initial model list
+	 * @param ec execution context
+	 * @param hyperParams hyper-parameter list (may be null)
+	 * @return parameter server
+	 */
+	private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType,
+			int workerNum, ListObject model, ExecutionContext ec, ListObject hyperParams) {
+		switch (mode) {
+		case LOCAL:
+			return new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+		case REMOTE_SPARK:
+			throw new DMLRuntimeException("Do not support remote spark.");
+		default:
+			return null;
+		}
+	}
+
+	/**
+	 * Parses the optional batch size argument, falling back to the default.
+	 *
+	 * @return positive batch size
+	 */
+	private long getBatchSize() {
+		if (getParameterMap().containsKey(PS_BATCH_SIZE)) {
+			long bsize = Integer.valueOf(getParam(PS_BATCH_SIZE));
+			if (bsize > 0) {
+				return bsize;
+			}
+			throw new DMLRuntimeException(String.format(
+					"Paramserv function: the number of argument '%s' could not be less than or equal to 0.",
+					PS_BATCH_SIZE));
+		}
+		return DEFAULT_BATCH_SIZE;
+	}
+
+	/**
+	 * @param ec execution context
+	 * @return the hyper-parameter list, or null if the argument is absent
+	 */
+	private ListObject getHyperParams(ExecutionContext ec) {
+		return getParameterMap().containsKey(PS_HYPER_PARAMS) ?
+				ec.getListObject(getParam(PS_HYPER_PARAMS)) : null;
+	}
+
+	/**
+	 * Partitions the training and validation data among the workers according
+	 * to the requested (or default) scheme; only DISJOINT_CONTIGUOUS is supported.
+	 *
+	 * @param ec execution context holding the data matrices
+	 * @param workers workers that receive the partitions
+	 */
+	private void doDataPartition(ExecutionContext ec, List<LocalPSWorker> workers) {
+		MatrixObject trainX = ec.getMatrixObject(getParam(PS_FEATURES));
+		MatrixObject trainY = ec.getMatrixObject(getParam(PS_LABELS));
+		MatrixObject valX = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
+		MatrixObject valY = ec.getMatrixObject(getParam(PS_VAL_LABELS));
+		PSScheme scheme = getParameterMap().containsKey(PS_SCHEME) ?
+				PSScheme.valueOf(getParam(PS_SCHEME)) : DEFAULT_SCHEME;
+		switch (scheme) {
+		case DISJOINT_RANDOM:
+		case OVERLAP_RESHUFFLE:
+		case DISJOINT_ROUND_ROBIN:
+			throw new DMLRuntimeException(
+					String.format("Paramserv function: the scheme '%s' is not supported.", scheme));
+		case DISJOINT_CONTIGUOUS:
+			disjointContiguous(trainX, trainY, valX, valY, workers);
+			break;
+		}
+	}
+
+	/**
+	 * Assigns disjoint, contiguous row partitions of the training and validation
+	 * data to the workers.
+	 */
+	private void disjointContiguous(MatrixObject features, MatrixObject labels, MatrixObject valFeatures,
+			MatrixObject valLabels, List<LocalPSWorker> workers) {
+		// training data
+		List<MatrixObject> trainX = disjointContiguous(workers.size(), features);
+		List<MatrixObject> trainY = disjointContiguous(workers.size(), labels);
+		if (trainX.size() < workers.size()) {
+			LOG.warn(String.format(
+					"There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.",
+					trainX.size(), workers.size(), trainX.size()));
+			// NOTE(review): this only shrinks the local view; the caller still holds
+			// (and starts threads for) all workers, so the surplus ones get no data.
+			workers = workers.subList(0, trainX.size());
+		}
+		int idx = 0;
+		for (LocalPSWorker worker : workers) {
+			worker.setFeatures(trainX.get(idx));
+			worker.setLabels(trainY.get(idx));
+			idx++;
+		}
+
+		// validation data
+		List<MatrixObject> valX = disjointContiguous(workers.size(), valFeatures);
+		List<MatrixObject> valY = disjointContiguous(workers.size(), valLabels);
+		idx = 0;
+		for (LocalPSWorker worker : workers) {
+			worker.setValFeatures(valX.get(idx));
+			worker.setValLabels(valY.get(idx));
+			idx++;
+		}
+	}
+
+	/**
+	 * Split the rows of the given matrix into at most workerNum contiguous,
+	 * disjoint partitions whose sizes differ by at most one row. If the matrix
+	 * has fewer rows than workerNum, one single-row partition per row is created.
+	 *
+	 * @param workerNum number of partitions to create
+	 * @param mo matrix to partition
+	 * @return row partitions, in row order
+	 */
+	private List<MatrixObject> disjointContiguous(int workerNum, MatrixObject mo) {
+		List<MatrixObject> list = new ArrayList<>();
+		long numRows = mo.getNumRows();
+		// spread the remainder over the first partitions, so that exactly
+		// min(numRows, workerNum) non-empty partitions are created
+		long baseSize = numRows / workerNum;
+		long remainder = numRows % workerNum;
+		long begin = 1;
+		for (int i = 0; i < workerNum && begin <= numRows; i++) {
+			long size = baseSize + (i < remainder ? 1 : 0);
+			// sliceMatrix takes 1-based, inclusive row bounds
+			long end = begin + size - 1;
+			list.add(ParamservUtils.sliceMatrix(mo, begin, end));
+			begin = end + 1;
+		}
+		return list;
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 47ea66e..43f5229 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1250,9 +1250,11 @@ public abstract class AutomatedTestBase
 			if (exceptionExpected)
 				fail("expected exception which has not been raised: " + expectedException);
 		} catch (Exception e) {
-			if (exceptionExpected && e.getClass().equals(expectedException) && errMessage != null
-					&& !e.getMessage().contains(errMessage)) {
-				fail("expected exception message has not been raised: " + errMessage);
+			if (errMessage != null && !errMessage.equals("")) {
+				boolean result = rCompareException(exceptionExpected, errMessage, e, false);
+				if (exceptionExpected && !result) {
+					fail(String.format("expected exception message '%s' has not been raised.", errMessage));
+				}
 			}
 			if (!exceptionExpected || (expectedException != null && !(e.getClass().equals(expectedException)))) {
 				e.printStackTrace();
@@ -1269,6 +1271,16 @@ public abstract class AutomatedTestBase
 		}
 	}
 
+	/**
+	 * Check whether the message of the given exception, or of any exception in
+	 * its cause chain, contains the expected error message.
+	 *
+	 * @param exceptionExpected whether an exception was expected at all
+	 * @param errMessage expected substring of the error message
+	 * @param e the raised exception
+	 * @param result value returned if no matching message is found
+	 * @return true if an exception was expected and a matching message was found, otherwise result
+	 */
+	private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
+		if (!exceptionExpected || errMessage == null)
+			return result;
+		for (Throwable t = e; t != null; t = t.getCause()) {
+			// messages may be null (e.g., for a plain NullPointerException)
+			String msg = t.getMessage();
+			if (msg != null && msg.contains(errMessage))
+				return true;
+		}
+		return result;
+	}
+
 	public void cleanupScratchSpace()
 	{
 		try

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
index 1b227f1..6370099 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
@@ -32,7 +32,8 @@ public class ParamservFuncTest extends AutomatedTestBase {
 	private static final String TEST_NAME4 = "paramserv-wrong-type-args";
 	private static final String TEST_NAME5 = "paramserv-wrong-args";
 	private static final String TEST_NAME6 = "paramserv-wrong-args2";
-	private static final String TEST_NAME7 = "paramserv-ipa-test";
+	private static final String TEST_NAME7 = "paramserv-nn-test";
+	private static final String TEST_NAME8 = "paramserv-minimum-version";
 
 	private static final String TEST_DIR = "functions/paramserv/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/";
@@ -48,53 +49,59 @@ public class ParamservFuncTest extends AutomatedTestBase {
 		addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
 		addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
 		addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+		addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {}));
 	}
 
 	@Test
 	public void testParamservWithAllArgs() {
-		runDMLTest(TEST_NAME1, true, false, null, null);
+		runDMLTest(TEST_NAME1, false, null, null);
 	}
 
 	@Test
 	public void testParamservWithoutOptionalArgs() {
-		runDMLTest(TEST_NAME2, true, false, null, null);
+		runDMLTest(TEST_NAME2, false, null, null);
 	}
 
 	@Test
 	public void testParamservMissArgs() {
 		final String errmsg = "Named parameter 'features' missing. Please specify the input.";
-		runDMLTest(TEST_NAME3, true, true, DMLException.class, errmsg);
+		runDMLTest(TEST_NAME3, true, DMLException.class, errmsg);
 	}
 
 	@Test
 	public void testParamservWrongTypeArgs() {
 		final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. It should not be of type 'MATRIX'";
-		runDMLTest(TEST_NAME4, true, true, DMLException.class, errmsg);
+		runDMLTest(TEST_NAME4, true, DMLException.class, errmsg);
 	}
 
 	@Test
 	public void testParamservWrongArgs() {
 		final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter.";
-		runDMLTest(TEST_NAME5, true, true, DMLException.class, errmsg);
+		runDMLTest(TEST_NAME5, true, DMLException.class, errmsg);
 	}
 
 	@Test
 	public void testParamservWrongArgs2() {
 		final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]";
-		runDMLTest(TEST_NAME6, true, true, DMLException.class, errmsg);
+		runDMLTest(TEST_NAME6, true, DMLException.class, errmsg);
 	}
 
+	// runs the script paramserv-nn-test.dml and expects it to finish without exception
 	@Test
-	public void testParamservIpaTest() {
-		runDMLTest(TEST_NAME7, true, false, null, "1");
+	public void testParamservNNTest() {
+		runDMLTest(TEST_NAME7, false, null, null);
 	}
 
-	private void runDMLTest(String testname, boolean newWay, boolean exceptionExpected, Class<?> exceptionClass,
+	// runs the script paramserv-minimum-version.dml and expects it to finish without exception
+	@Test
+	public void testParamservMinimumVersionTest() {
+		runDMLTest(TEST_NAME8, false, null, null);
+	}
+
+	/**
+	 * Run the DML script HOME + testname + ".dml" with "-explain".
+	 *
+	 * @param testname name of the test configuration and script
+	 * @param exceptionExpected whether the script is expected to fail
+	 * @param exceptionClass expected exception class, or null
+	 * @param errmsg expected substring of the exception message (searched in
+	 *               the whole cause chain), or null
+	 */
+	private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass,
 			String errmsg) {
 		TestConfiguration config = getTestConfiguration(testname);
 		loadTestConfiguration(config);
 		programArgs = new String[] { "-explain" };
 		fullDMLScriptName = HOME + testname + ".dml";
-		runTest(newWay, exceptionExpected, exceptionClass, errmsg, -1);
+		runTest(true, exceptionExpected, exceptionClass, errmsg, -1);
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
new file mode 100644
index 0000000..2a3bbe2
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -0,0 +1,383 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * The training itself is delegated to the paramserv builtin (LOCAL
+   * mode, BSP updates per batch). It uses the 'gradients' and
+   * 'aggregation' functions of this script as its update and aggregation
+   * functions.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *  - workers: Number of parallel workers (passed to paramserv as 'k').
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  # NOTE(review): decay is passed in the hyper-parameter list below, but
+  # aggregation() does not apply it (lr and mu are hard-coded there as well)
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model object (weights, biases, and the optimizer velocities)
+  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+  # Create the hyper parameter list
+  # NOTE: gradients() and aggregation() currently hard-code these values
+  # instead of reading them from 'hyperparams'
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")
+
+  # Extract the trained weights and biases from the resulting model list
+  W1 = as.matrix(modelList2["W1"])
+  b1 = as.matrix(modelList2["b1"])
+  W2 = as.matrix(modelList2["W2"])
+  b2 = as.matrix(modelList2["b2"])
+  W3 = as.matrix(modelList2["W3"])
+  b3 = as.matrix(modelList2["b3"])
+  W4 = as.matrix(modelList2["W4"])
+  b4 = as.matrix(modelList2["b4"])
+
+}
+
+# The update function must take the arguments 'features' (batch features),
+# 'labels' (batch labels), 'hyperparams', and 'model', and must return
+# the gradients as a list.
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  /*
+   * Computes the gradients of the LeNet model for one mini-batch. It runs
+   * the forward pass and the backward pass, and adds L2 regularization on
+   * the weights.
+   *
+   * Inputs:
+   *  - features: Batch input data, of shape (N, C*Hin*Win).
+   *  - labels: Batch one-hot targets, of shape (N, K).
+   *  - hyperparams: Hyper-parameter list (currently not read, see below).
+   *  - model: Model list with (at least) the entries W1..W4 and b1..b4.
+   *
+   * Outputs:
+   *  - gradients: List with the entries dW1..dW4 and db1..db4.
+   */
+
+  # Reading scalars from the 'hyperparams' list is not supported here, so the
+  # network configuration is hard-coded and must be kept consistent with
+  # train() (C=1, Hin=Win=28 assume MNIST inputs).
+  C = 1
+  Hin = 28
+  Win = 28
+  Hf = 5
+  Wf = 5
+  stride = 1
+  pad = 2
+  lambda = 5e-04
+  F1 = 32
+  F2 = 64
+  N3 = 512
+  W1 = as.matrix(model["W1"])
+  b1 = as.matrix(model["b1"])
+  W2 = as.matrix(model["W2"])
+  b2 = as.matrix(model["b2"])
+  W3 = as.matrix(model["W3"])
+  b3 = as.matrix(model["b3"])
+  W4 = as.matrix(model["W4"])
+  b4 = as.matrix(model["b4"])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  # the input of affine4 in the forward pass is the dropout output (outd3)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+}
+
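+# Open question: how should the optimizer velocities be handled? Currently
+# they are stored in the model list alongside the weights and biases.
+# The aggregation function must take the arguments 'model', 'gradients',
+# and 'hyperparams', and must always return the updated model as a list.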
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+   return (list[unknown] modelResult) {
+
+     # Applies one SGD step with Nesterov momentum to every weight and bias.
+     # The step uses the velocities (vW*, vb*) stored in 'model', and the
+     # updated parameters and velocities are returned as a new model list.
+     # NOTE: 'hyperparams' is not read; lr and mu are hard-coded below to
+     # the same values that train() uses.
+     W1 = as.matrix(model["W1"])
+     W2 = as.matrix(model["W2"])
+     W3 = as.matrix(model["W3"])
+     W4 = as.matrix(model["W4"])
+     b1 = as.matrix(model["b1"])
+     b2 = as.matrix(model["b2"])
+     b3 = as.matrix(model["b3"])
+     b4 = as.matrix(model["b4"])
+     dW1 = as.matrix(gradients["dW1"])
+     dW2 = as.matrix(gradients["dW2"])
+     dW3 = as.matrix(gradients["dW3"])
+     dW4 = as.matrix(gradients["dW4"])
+     db1 = as.matrix(gradients["db1"])
+     db2 = as.matrix(gradients["db2"])
+     db3 = as.matrix(gradients["db3"])
+     db4 = as.matrix(gradients["db4"])
+     vW1 = as.matrix(model["vW1"])
+     vW2 = as.matrix(model["vW2"])
+     vW3 = as.matrix(model["vW3"])
+     vW4 = as.matrix(model["vW4"])
+     vb1 = as.matrix(model["vb1"])
+     vb2 = as.matrix(model["vb2"])
+     vb3 = as.matrix(model["vb3"])
+     vb4 = as.matrix(model["vb4"])
+     lr = 0.01
+     mu = 0.9
+
+     # Optimize with SGD w/ Nesterov momentum
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     # Updated velocities are kept in the model so that the next call sees them
+     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  Rows are processed in
+   * mini-batches of 64 (the last batch may be shorter), and this
+   * forward pass contains no dropout layer.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch; `end` is clamped to N so the final batch may be shorter
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    # 2x2 max pooling, stride 2, zero padding.  NOTE(review): the original call
+    # named both padding args `pad`; padh/padw are assumed to be the pooling
+    # layer's parameter names -- confirm against nn/layers/max_pool2d_builtin.dml.
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Scores LeNet class-probability predictions against one-hot targets.
+   *
+   * Inputs:
+   *  - probs: Predicted class probabilities, of shape (N, K).
+   *  - Y: One-hot encoded target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Cross-entropy loss of probs with respect to Y (scalar).
+   *  - accuracy: Fraction of rows whose highest-probability class equals
+   *    the target class (scalar).
+   */
+  # Class index per row: predicted (argmax of probs) vs. actual (argmax of Y)
+  predicted = rowIndexMax(probs)
+  actual = rowIndexMax(Y)
+  accuracy = mean(predicted == actual)
+
+  # Loss over the same examples
+  loss = cross_entropy_loss::forward(probs, Y)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix of normally distributed values, of shape (N, C*Hin*Win).
+   *  - Y: One-hot encoded target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  # Rounding a uniform draw over [1, K] makes classes 1 and K about half as
+  # likely as the others; acceptable for dummy data.
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  # Pass the output dimensions explicitly so Y is always (N, K), even when
+  # the largest class index happens not to be drawn.
+  Y = table(seq(1, N), classes, N, K)  # one-hot encoding
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
new file mode 100644
index 0000000..2ef7411
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -0,0 +1,377 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture, delegating
+   * the training loop to the paramserv builtin (mode="LOCAL", utype="BSP").
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N_val, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N_val, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *  - workers: Currently unused; it is not forwarded to paramserv.
+   *    NOTE(review): presumably omitted on purpose so this "minimum
+   *    version" exercises paramserv's defaults -- confirm.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model object; the optimizer velocities travel in the same
+  # list as the parameters so the aggregation function can update them.
+  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+  # Create the hyper parameter list.
+  # NOTE(review): the gradients/aggregation functions in this file do not
+  # read this list; they hard-code their own copies of these values.
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function (upd/agg are relative script paths, presumably
+  # resolved against the working directory of the test run)
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
+
+  # Unpack the trained parameters (velocities are discarded)
+  W1 = as.matrix(modelList2["W1"])
+  b1 = as.matrix(modelList2["b1"])
+  W2 = as.matrix(modelList2["W2"])
+  b2 = as.matrix(modelList2["b2"])
+  W3 = as.matrix(modelList2["W3"])
+  b3 = as.matrix(modelList2["b3"])
+  W4 = as.matrix(modelList2["W4"])
+  b4 = as.matrix(modelList2["b4"])
+
+}
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  /*
+   * Computes LeNet gradients for the given features and labels.  It runs a
+   * forward pass (with dropout after affine3/relu3), backpropagates the
+   * cross-entropy loss, and adds L2-regularization gradients to the weight
+   * gradients.
+   *
+   * Inputs:
+   *  - features: Input data matrix, of shape (N, C*Hin*Win).
+   *  - labels: One-hot encoded target matrix, of shape (N, K).
+   *  - hyperparams: Hyperparameter list.  Not read here (see note below).
+   *  - model: List providing W1..W4 and b1..b4, looked up by name.
+   *
+   * Outputs:
+   *  - gradients: List with entries dW1..dW4 and db1..db4.
+   */
+
+  # NOTE(review): these values are hard-coded instead of being read from
+  # hyperparams, so this function only supports C=1, 28x28 inputs,
+  # regardless of the C/Hin/Win that train() receives.
+  C = 1
+  Hin = 28
+  Win = 28
+  Hf = 5
+  Wf = 5
+  stride = 1
+  pad = 2
+  lambda = 5e-04
+  F1 = 32
+  F2 = 64
+  W1 = as.matrix(model["W1"])
+  b1 = as.matrix(model["b1"])
+  W2 = as.matrix(model["W2"])
+  b2 = as.matrix(model["b2"])
+  W3 = as.matrix(model["W3"])
+  b3 = as.matrix(model["b3"])
+  W4 = as.matrix(model["W4"])
+  b4 = as.matrix(model["b4"])
+
+  # Compute forward pass
+  # NOTE(review): the max_pool2d calls previously named both padding args
+  # `pad`; padh/padw are assumed to be the pooling layer's parameter names.
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, padh=0, padw=0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, padh=0, padw=0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, padh=0, padw=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, padh=0, padw=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  # dX_batch (gradient w.r.t. the input) must be bound but is not used
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+
+}
+
+# Aggregation function passed to paramserv (agg=...): applies one SGD step
+# with Nesterov momentum to every weight and bias.  The momentum velocities
+# live in the model list under vW1..vW4 / vb1..vb4 and are returned, updated,
+# in modelResult; that is how velocity is carried between updates
+# (presumably paramserv feeds modelResult back in as the next model).
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+   return (list[unknown] modelResult) {
+
+     # NOTE(review): hyperparams is not read; lr and mu are hard-coded and
+     # mirror the values set in train().
+     lr = 0.01
+     mu = 0.9
+
+     # Optimize with SGD w/ Nesterov momentum, one parameter at a time:
+     # fetch the parameter, its gradient and its velocity, then update both.
+     ## conv1
+     W1 = as.matrix(model["W1"])
+     dW1 = as.matrix(gradients["dW1"])
+     vW1 = as.matrix(model["vW1"])
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     b1 = as.matrix(model["b1"])
+     db1 = as.matrix(gradients["db1"])
+     vb1 = as.matrix(model["vb1"])
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+
+     ## conv2
+     W2 = as.matrix(model["W2"])
+     dW2 = as.matrix(gradients["dW2"])
+     vW2 = as.matrix(model["vW2"])
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     b2 = as.matrix(model["b2"])
+     db2 = as.matrix(gradients["db2"])
+     vb2 = as.matrix(model["vb2"])
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+
+     ## affine3
+     W3 = as.matrix(model["W3"])
+     dW3 = as.matrix(gradients["dW3"])
+     vW3 = as.matrix(model["vW3"])
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     b3 = as.matrix(model["b3"])
+     db3 = as.matrix(gradients["db3"])
+     vb3 = as.matrix(model["vb3"])
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+
+     ## affine4
+     W4 = as.matrix(model["W4"])
+     dW4 = as.matrix(gradients["dW4"])
+     vW4 = as.matrix(model["vW4"])
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     b4 = as.matrix(model["b4"])
+     db4 = as.matrix(gradients["db4"])
+     vb4 = as.matrix(model["vb4"])
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  Rows are processed in
+   * mini-batches of 64 (the last batch may be shorter), and this
+   * forward pass contains no dropout layer.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch; `end` is clamped to N so the final batch may be shorter
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    # 2x2 max pooling, stride 2, zero padding.  NOTE(review): the original call
+    # named both padding args `pad`; padh/padw are assumed to be the pooling
+    # layer's parameter names -- confirm against nn/layers/max_pool2d_builtin.dml.
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Scores LeNet class-probability predictions against one-hot targets.
+   *
+   * Inputs:
+   *  - probs: Predicted class probabilities, of shape (N, K).
+   *  - Y: One-hot encoded target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Cross-entropy loss of probs with respect to Y (scalar).
+   *  - accuracy: Fraction of rows whose highest-probability class equals
+   *    the target class (scalar).
+   */
+  # Class index per row: predicted (argmax of probs) vs. actual (argmax of Y)
+  predicted = rowIndexMax(probs)
+  actual = rowIndexMax(Y)
+  accuracy = mean(predicted == actual)
+
+  # Loss over the same examples
+  loss = cross_entropy_loss::forward(probs, Y)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix of normally distributed values, of shape (N, C*Hin*Win).
+   *  - Y: One-hot encoded target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  # Rounding a uniform draw over [1, K] makes classes 1 and K about half as
+  # likely as the others; acceptable for dummy data.
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  # Pass the output dimensions explicitly so Y is always (N, K), even when
+  # the largest class index happens not to be drawn.
+  Y = table(seq(1, N), classes, N, K)  # one-hot encoding
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-all-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
index bcb3ac3..ec6e087 100644
--- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 e1 = "element1"
-paramsList = list(e1)
+paramsList = list(e1=e1)
 X = matrix(1, rows=2, cols=3)
 Y = matrix(2, rows=2, cols=3)
 X_val = matrix(3, rows=2, cols=3)
@@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) {
 }
 
 e2 = "element2"
-hps = list(e2)
+hps = list(e2=e2)
 
 # Use paramserv function
 paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
deleted file mode 100644
index 5aed767..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
+++ /dev/null
@@ -1,47 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-e1 = "element1"
-paramsList = list(e1)
-X = matrix(1, rows=2, cols=3)
-Y = matrix(2, rows=2, cols=3)
-X_val = matrix(3, rows=2, cols=3)
-Y_val = matrix(4, rows=2, cols=3)
-
-gradients = function (matrix[double] input) return (matrix[double] output) {
-  output = input
-}
-
-aggregation = function (matrix[double] input) return (matrix[double] output) {
-  output = input
-}
-
-e2 = "element2"
-hps = list(e2)
-
-# Use paramserv function
-paramsList2 = list(1, 2, 3)
-
-if (length(paramsList2) == 3) {
-  paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
-}
-
-print(length(paramsList2))
\ No newline at end of file