You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/22 05:31:09 UTC

[1/2] systemml git commit: [SYSTEMML-2420, 2422] New distributed paramserv spark workers and rpc

Repository: systemml
Updated Branches:
  refs/heads/master 54dbe9bb2 -> 15ecb723e


http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
deleted file mode 100644
index 8605984..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
new file mode 100644
index 0000000..4d0f32e
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model
+  print(toString(as.matrix(gradients["agg_service_err"])))
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
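+# Use paramserv function (presumably expected to fail: aggregation() looks up "agg_service_err", a name that list("model") above does not define)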
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
deleted file mode 100644
index 31d44aa..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 16
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "REMOTE_SPARK")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
new file mode 100644
index 0000000..ad16122
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model
+  print(toString(as.matrix(gradients["worker_err"])))
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
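+# Use paramserv function (presumably expected to fail: gradients() looks up "worker_err", a name that list("model") above does not define)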
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-test.dml b/src/test/scripts/functions/paramserv/paramserv-test.dml
new file mode 100644
index 0000000..b21e9c0
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-test.dml
@@ -0,0 +1,48 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data a second time (note: X and Y are overwritten by the split below)
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, $epochs, $workers, $utype, $freq, $batchsize, $scheme, $mode)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, $batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
index d1b3a6d..a99035f 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
@@ -30,9 +30,11 @@ import org.junit.runners.Suite;
 	SparkDataPartitionerTest.class,
 	ParamservSyntaxTest.class,
 	SerializationTest.class,
+	RpcObjectTest.class,
 	ParamservRecompilationTest.class,
 	ParamservRuntimeNegativeTest.class,
-	ParamservLocalNNTest.class
+	ParamservLocalNNTest.class,
+	ParamservSparkNNTest.class
 })
 
 


[2/2] systemml git commit: [SYSTEMML-2420, 2422] New distributed paramserv spark workers and rpc

Posted by mb...@apache.org.
[SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc

Closes #805.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/15ecb723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/15ecb723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/15ecb723

Branch: refs/heads/master
Commit: 15ecb723e39e3154412ca8f8824c4554ee64ca35
Parents: 54dbe9b
Author: EdgarLGB <gu...@atos.net>
Authored: Sat Jul 21 22:31:36 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Jul 21 22:31:36 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java |  34 +++---
 .../paramserv/LocalParamServer.java             |   7 +-
 .../controlprogram/paramserv/PSWorker.java      |  15 ++-
 .../controlprogram/paramserv/ParamServer.java   |  39 ++++---
 .../paramserv/ParamservUtils.java               |  65 ++++++-----
 .../paramserv/spark/SparkPSBody.java            |   6 +-
 .../paramserv/spark/SparkPSProxy.java           |  68 +++++++++++
 .../paramserv/spark/SparkPSWorker.java          |  46 ++++++--
 .../paramserv/spark/rpc/PSRpcCall.java          |  97 ++++++++++++++++
 .../paramserv/spark/rpc/PSRpcFactory.java       |  57 ++++++++++
 .../paramserv/spark/rpc/PSRpcHandler.java       |  83 ++++++++++++++
 .../paramserv/spark/rpc/PSRpcObject.java        |  57 ++++++++++
 .../paramserv/spark/rpc/PSRpcResponse.java      | 112 +++++++++++++++++++
 .../cp/ParamservBuiltinCPInstruction.java       |  52 +++++++--
 .../sysml/runtime/util/ProgramConverter.java    |  11 +-
 .../java/org/apache/sysml/utils/Statistics.java |   6 +
 .../paramserv/ParamservLocalNNTest.java         |  41 +++----
 .../paramserv/ParamservSparkNNTest.java         |  68 +++++++++--
 .../functions/paramserv/RpcObjectTest.java      |  56 ++++++++++
 .../functions/paramserv/SerializationTest.java  |   2 +-
 .../paramserv/paramserv-nn-asp-batch.dml        |  53 ---------
 .../paramserv/paramserv-nn-asp-epoch.dml        |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-dc.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-dr.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-drr.dml    |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-or.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-epoch.dml        |  53 ---------
 .../paramserv-spark-agg-service-failed.dml      |  53 +++++++++
 .../paramserv-spark-nn-bsp-batch-dc.dml         |  53 ---------
 .../paramserv/paramserv-spark-worker-failed.dml |  53 +++++++++
 .../functions/paramserv/paramserv-test.dml      |  48 ++++++++
 .../functions/paramserv/ZPackageSuite.java      |   4 +-
 32 files changed, 961 insertions(+), 543 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index bbf2dbe..c23943d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -35,6 +35,9 @@ import org.apache.sysml.utils.Statistics;
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
 	protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+	private static final long serialVersionUID = 5195390748495357295L;
+
+	protected LocalPSWorker() {}
 
 	public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
 		MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
@@ -42,6 +45,11 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 	}
 
 	@Override
+	public String getWorkerName() {
+		return String.format("Local worker_%d", _workerID);
+	}
+
+	@Override
 	public Void call() throws Exception {
 		if (DMLScript.STATISTICS)
 			Statistics.incWorkerNumber();
@@ -60,10 +68,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 			}
 
 			if (LOG.isDebugEnabled()) {
-				LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+				LOG.debug(String.format("%s: job finished.", getWorkerName()));
 			}
 		} catch (Exception e) {
-			throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
+			throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
 		}
 		return null;
 	}
@@ -93,7 +101,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 			ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 
 			if (LOG.isDebugEnabled()) {
-				LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
 			}
 		}
 
@@ -108,9 +116,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 			Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());
 		
 		if (LOG.isDebugEnabled()) {
-			LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. "
+			LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
 				+ "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]",
-				_workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
+				getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
 		}
 		return globalParams;
 	}
@@ -129,17 +137,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 				ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 			}
 			if (LOG.isDebugEnabled()) {
-				LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
 			}
 		}
 	}
 
 	private ListObject pullModel() {
 		// Pull the global parameters from ps
-		ListObject globalParams = (ListObject)_ps.pull(_workerID);
+		ListObject globalParams = _ps.pull(_workerID);
 		if (LOG.isDebugEnabled()) {
-			LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
-				+ "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
+			LOG.debug(String.format("%s: successfully pull the global parameters "
+				+ "[size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / 1024));
 		}
 		return globalParams;
 	}
@@ -148,8 +156,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		// Push the gradients to ps
 		_ps.push(_workerID, gradients);
 		if (LOG.isDebugEnabled()) {
-			LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
-				+ "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+			LOG.debug(String.format("%s: successfully push the gradients "
+				+ "[size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / 1024));
 		}
 	}
 
@@ -168,8 +176,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		_ec.setVariable(Statement.PS_LABELS, bLabels);
 
 		if (LOG.isDebugEnabled()) {
-			LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
-				+ "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", _workerID,
+			LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
+				+ "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(),
 				bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
 				j + 1, totalIter));
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index 52372c9..0c73acb 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -22,11 +22,14 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public class LocalParamServer extends ParamServer {
 
+	public LocalParamServer() {
+		super();
+	}
+
 	public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
 		super(model, aggFunc, updateType, ec, workerNum);
 	}
@@ -37,7 +40,7 @@ public class LocalParamServer extends ParamServer {
 	}
 
 	@Override
-	public Data pull(int workerID) {
+	public ListObject pull(int workerID) {
 		ListObject model;
 		try {
 			model = _modelMap.get(workerID).take();

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 1ab5f5e..464db9b 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -21,6 +21,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.stream.Collectors;
 
@@ -34,7 +35,10 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 
-public abstract class PSWorker {
+public abstract class PSWorker implements Serializable {
+
+	private static final long serialVersionUID = -3510485051178200118L;
+
 	protected int _workerID;
 	protected int _epochs;
 	protected long _batchSize;
@@ -50,10 +54,8 @@ public abstract class PSWorker {
 	protected String _updFunc;
 	protected Statement.PSFrequency _freq;
 
-	protected PSWorker() {
+	protected PSWorker() {}
 
-	}
-	
 	protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
 		MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
 		_workerID = workerID;
@@ -65,7 +67,10 @@ public abstract class PSWorker {
 		_valLabels = valLabels;
 		_ec = ec;
 		_ps = ps;
+		setupUpdateFunction(updFunc, ec);
+	}
 
+	protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
 		// Get the update function
 		String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX);
 		String ns = cfn[0];
@@ -125,4 +130,6 @@ public abstract class PSWorker {
 	public MatrixObject getLabels() {
 		return _labels;
 	}
+
+	public abstract String getWorkerName();
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index bd8ee36..2607036 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -42,7 +42,6 @@ import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.utils.Statistics;
@@ -53,17 +52,19 @@ public abstract class ParamServer
 	protected static final boolean ACCRUE_BSP_GRADIENTS = true;
 	
 	// worker input queues and global model
-	protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
+	protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
 	private ListObject _model;
 
 	//aggregation service
-	protected final ExecutionContext _ec;
-	private final Statement.PSUpdateType _updateType;
-	private final FunctionCallCPInstruction _inst;
-	private final String _outputName;
-	private final boolean[] _finishedStates;  // Workers' finished states
+	protected ExecutionContext _ec;
+	private Statement.PSUpdateType _updateType;
+	private FunctionCallCPInstruction _inst;
+	private String _outputName;
+	private boolean[] _finishedStates;  // Workers' finished states
 	private ListObject _accGradients = null;
 
+	protected ParamServer() {}
+
 	protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
 		// init worker queues and global model
 		_modelMap = new HashMap<>(workerNum);
@@ -77,10 +78,22 @@ public abstract class ParamServer
 		_ec = ec;
 		_updateType = updateType;
 		_finishedStates = new boolean[workerNum];
+		setupAggFunc(_ec, aggFunc);
+		
+		// broadcast initial model
+		try {
+			broadcastModel();
+		}
+		catch (InterruptedException e) {
+			throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
+		}
+	}
+
+	public void setupAggFunc(ExecutionContext ec, String aggFunc) {
 		String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
 		String ns = cfn[0];
 		String fname = cfn[1];
-		FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname);
+		FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname);
 		ArrayList<DataIdentifier> inputs = func.getInputParams();
 		ArrayList<DataIdentifier> outputs = func.getOutputParams();
 
@@ -101,19 +114,11 @@ public abstract class ParamServer
 		ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
 			.collect(Collectors.toCollection(ArrayList::new));
 		_inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function");
-		
-		// broadcast initial model
-		try {
-			broadcastModel();
-		}
-		catch (InterruptedException e) {
-			throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
-		}
 	}
 
 	public abstract void push(int workerID, ListObject value);
 
-	public abstract Data pull(int workerID);
+	public abstract ListObject pull(int workerID);
 
 	public ListObject getResult() {
 		// All the model updating work has terminated,

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index b9fd7a8..cf27457 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -28,8 +28,11 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.MultiThreadedHop;
@@ -57,6 +60,7 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
@@ -68,13 +72,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
 public class ParamservUtils {
 
+	protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
 	public static final String PS_FUNC_PREFIX = "_ps_";
-
 	public static long SEED = -1; // Used for generating permutation
 
 	/**
@@ -140,6 +145,14 @@ public class ParamservUtils {
 		CacheableData<?> cd = (CacheableData<?>) data;
 		cd.enableCleanup(true);
 		ec.cleanupCacheableData(cd);
+		if (LOG.isDebugEnabled()) {
+			LOG.debug(String.format("%s has been deleted.", cd.getFileName()));
+		}
+	}
+
+	public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) {
+		mo.enableCleanup(true);
+		ec.cleanupCacheableData(mo);
 	}
 
 	public static MatrixObject newMatrixObject(MatrixBlock mb) {
@@ -365,6 +378,7 @@ public class ParamservUtils {
 
 	@SuppressWarnings("unchecked")
 	public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
+		Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
 		// Get input RDD
 		JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
 				sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
@@ -372,33 +386,34 @@ public class ParamservUtils {
 				sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);
 
 		DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
-		return ParamservUtils.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
+		JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils
+			.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
 			.flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels))
 			// Aggregate the partitioned matrix according to rowID for each worker
 			// i.e. (workerID => ordered list[(rowBlockID, (single row features, single row labels)]
-			.aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(),
-				new Partitioner() {
-					private static final long serialVersionUID = -7937781374718031224L;
-					@Override
-					public int getPartition(Object workerID) {
-						return (int) workerID;
-					}
-					@Override
-					public int numPartitions() {
-						return workerNum;
-					}
-				}, 
-				(list, input) -> {
-					list.add(input);
-					return list;
-				},
-				(l1, l2) -> {
-					l1.addAll(l2);
-					l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
-					return l1;
-				})
-			.mapToPair(new DataPartitionerSparkAggregator(
-				features.getNumColumns(), labels.getNumColumns()));
+			.aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
+				private static final long serialVersionUID = -7937781374718031224L;
+				@Override
+				public int getPartition(Object workerID) {
+					return (int) workerID;
+				}
+				@Override
+				public int numPartitions() {
+					return workerNum;
+				}
+			}, (list, input) -> {
+				list.add(input);
+				return list;
+			}, (l1, l2) -> {
+				l1.addAll(l2);
+				l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
+				return l1;
+			})
+			.mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));
+
+		if (DMLScript.STATISTICS)
+			Statistics.accPSSetupTime((long) tSetup.stop());
+		return result;
 	}
 
 	public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
index ec10232..9354025 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
@@ -28,12 +28,10 @@ public class SparkPSBody {
 
 	private ExecutionContext _ec;
 
-	public SparkPSBody() {
-
-	}
+	public SparkPSBody() {}
 
 	public SparkPSBody(ExecutionContext ec) {
-		this._ec = ec;
+		_ec = ec;
 	}
 
 	public ExecutionContext getEc() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
new file mode 100644
index 0000000..de7b6c6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Stand-in for the parameter server: every push/pull is forwarded as a
+ * blocking rpc through the given {@link TransportClient}.
+ */
+public class SparkPSProxy extends ParamServer {
+
+	private TransportClient _client;
+	private final long _rpcTimeout;
+
+	/**
+	 * @param client     transport client used to send the rpc requests
+	 * @param rpcTimeout timeout handed to {@code sendRpcSync} for every request
+	 */
+	public SparkPSProxy(TransportClient client, long rpcTimeout) {
+		super();
+		_client = client;
+		_rpcTimeout = rpcTimeout;
+	}
+
+	/**
+	 * Sends the given value to the server as a PUSH request.
+	 *
+	 * @throws DMLRuntimeException if the response is not marked successful
+	 */
+	@Override
+	public void push(int workerID, ListObject value) {
+		PSRpcResponse reply = invoke(new PSRpcCall(PUSH, workerID, value));
+		if (reply.isSuccessful())
+			return;
+		throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, reply.getErrorMessage()));
+	}
+
+	/**
+	 * Issues a PULL request and returns the model carried by the response.
+	 *
+	 * @throws DMLRuntimeException if the response is not marked successful
+	 */
+	@Override
+	public ListObject pull(int workerID) {
+		PSRpcResponse reply = invoke(new PSRpcCall(PULL, workerID, null));
+		if (reply.isSuccessful())
+			return reply.getResultModel();
+		throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, reply.getErrorMessage()));
+	}
+
+	/**
+	 * Serializes the request, sends it synchronously and wraps the raw reply.
+	 * The elapsed time is added to the rpc statistics when statistics are enabled.
+	 */
+	private PSRpcResponse invoke(PSRpcCall request) {
+		Timing timer = DMLScript.STATISTICS ? new Timing(true) : null;
+		PSRpcResponse reply = new PSRpcResponse(_client.sendRpcSync(request.serialize(), _rpcTimeout));
+		if (DMLScript.STATISTICS)
+			Statistics.accPSRpcRequestTime((long) timer.stop());
+		return reply;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index 466801f..fa06243 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -20,43 +20,58 @@
 package org.apache.sysml.runtime.controlprogram.paramserv.spark;
 
 import java.io.IOException;
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
-public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
+public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
 
 	private static final long serialVersionUID = -8674739573419648732L;
 
 	private String _program;
 	private HashMap<String, byte[]> _clsMap;
+	private String _host; // host ip of driver
+	private long _rpcTimeout; // rpc ask timeout
+	private String _aggFunc;
 
-	protected SparkPSWorker() {
-		// No-args constructor used for deserialization
-	}
-
-	public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) {
+	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) {
 		_updFunc = updFunc;
+		_aggFunc = aggFunc;
 		_freq = freq;
 		_epochs = epochs;
 		_batchSize = batchSize;
 		_program = program;
 		_clsMap = clsMap;
+		_host = host;
+		_rpcTimeout = rpcTimeout;
+	}
+
+	@Override
+	public String getWorkerName() {
+		return String.format("Spark worker_%d", _workerID);
 	}
 
 	@Override
 	public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
+		Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
 		configureWorker(input);
+		if (DMLScript.STATISTICS)
+			Statistics.accPSSetupTime((long) tSetup.stop());
+		call(); // Launch the worker
 	}
 
 	private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
@@ -73,5 +88,20 @@ public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integ
 
 		// Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end
 		RemoteParForUtils.setupBufferPool(_workerID);
+
+		// Create the ps proxy
+		_ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout);
+
+		// Initialize the update function
+		setupUpdateFunction(_updFunc, _ec);
+
+		// Initialize the agg function
+		_ps.setupAggFunc(_ec, _aggFunc);
+
+		// Lazy initialize the matrix of features and labels
+		setFeatures(ParamservUtils.newMatrixObject(input._2._1));
+		setLabels(ParamservUtils.newMatrixObject(input._2._2));
+		_features.enableCleanup(false);
+		_labels.enableCleanup(false);
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
new file mode 100644
index 0000000..999d409
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+import java.util.StringTokenizer;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+/**
+ * Request message of the parameter-server rpc: a method name, the id of the
+ * calling worker and an optional list object payload.
+ */
+public class PSRpcCall extends PSRpcObject {
+
+	private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN;
+	private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END;
+
+	private String _method;
+	private int _workerID;
+	private ListObject _data;
+
+	/**
+	 * @param method   name of the requested method
+	 * @param workerID id of the calling worker
+	 * @param data     payload of the call, or null if the call carries none
+	 */
+	public PSRpcCall(String method, int workerID, ListObject data) {
+		_method = method;
+		_workerID = workerID;
+		_data = data;
+	}
+
+	/**
+	 * Rebuilds a call from the bytes produced by {@link #serialize()}.
+	 *
+	 * @param buffer serialized call
+	 */
+	public PSRpcCall(ByteBuffer buffer) {
+		deserialize(buffer);
+	}
+
+	/**
+	 * Parses the textual form written by {@link #serialize()}: the fixed
+	 * begin/end markers are stripped and the remainder is split into
+	 * method, worker id and data.
+	 *
+	 * @param buffer serialized call
+	 */
+	@Override
+	public void deserialize(ByteBuffer buffer) {
+		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+		String input = bufferToString(buffer);
+		//header elimination
+		input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end
+		StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
+
+		_method = st.nextToken();
+		// parseInt yields the primitive directly instead of boxing and unboxing
+		_workerID = Integer.parseInt(st.nextToken());
+		String dataStr = st.nextToken();
+		// the EMPTY marker stands for a call without payload
+		_data = dataStr.equals(EMPTY) ? null :
+			(ListObject) ProgramConverter.parseDataObject(dataStr)[1];
+	}
+
+	/**
+	 * Writes the call as begin marker, method, worker id and data (or the
+	 * EMPTY marker when there is no payload), separated by the component
+	 * delimiter and closed by the end marker.
+	 *
+	 * @return buffer wrapping the bytes of the serialized call
+	 */
+	@Override
+	public ByteBuffer serialize() {
+		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+		StringBuilder sb = new StringBuilder();
+		sb.append(PS_RPC_CALL_BEGIN);
+		sb.append(_method);
+		sb.append(COMPONENTS_DELIM);
+		sb.append(_workerID);
+		sb.append(COMPONENTS_DELIM);
+		if (_data == null) {
+			sb.append(EMPTY);
+		} else {
+			// export the cacheable entries before the list object is serialized
+			flushListObject(_data);
+			sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data));
+		}
+		sb.append(PS_RPC_CALL_END);
+		return ByteBuffer.wrap(sb.toString().getBytes());
+	}
+
+	public String getMethod() {
+		return _method;
+	}
+
+	public int getWorkerID() {
+		return _workerID;
+	}
+
+	public ListObject getData() {
+		return _data;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
new file mode 100644
index 0000000..c8b4024
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
+
+//TODO should be able to configure the port by users
+/**
+ * Factory for the two endpoints of the parameter-server rpc: the server
+ * that wraps a local parameter server and the proxy that connects to it.
+ */
+public class PSRpcFactory {
+
+	private static final String MODULE_NAME = "ps";
+	private static final int PORT = 5055;
+
+	/**
+	 * Create and start the server
+	 * @param ps parameter server that handles the incoming calls
+	 * @param host host name the server is created for
+	 * @return server
+	 */
+	public static TransportServer createServer(LocalParamServer ps, String host) {
+		return createTransportContext(ps).createServer(host, PORT, Collections.emptyList());
+	}
+
+	/**
+	 * Create a proxy whose client is connected to the given host
+	 * @param host host name of the server
+	 * @param rpcTimeout timeout passed on to the proxy
+	 * @return proxy
+	 * @throws IOException if the client cannot be created
+	 */
+	public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException {
+		TransportContext proxyContext = createTransportContext(new LocalParamServer());
+		return new SparkPSProxy(proxyContext.createClientFactory().createClient(host, PORT), rpcTimeout);
+	}
+
+	// Builds a transport context whose rpc handler wraps the given server
+	private static TransportContext createTransportContext(LocalParamServer ps) {
+		return new TransportContext(
+			new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider()),
+			new PSRpcHandler(ps));
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
new file mode 100644
index 0000000..3d73a37
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS;
+
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Rpc handler that dispatches deserialized {@link PSRpcCall}s to the
+ * wrapped parameter server and answers with a serialized
+ * {@link PSRpcResponse}.
+ */
+public final class PSRpcHandler extends RpcHandler {
+
+	private final LocalParamServer _server;
+
+	protected PSRpcHandler(LocalParamServer server) {
+		_server = server;
+	}
+
+	/**
+	 * Executes a single push or pull call and reports its outcome through the callback.
+	 * Any runtime failure of the parameter server is sent back as an ERROR
+	 * response that carries the full stack trace.
+	 *
+	 * @param client   transport client of the caller (not used)
+	 * @param buffer   serialized {@link PSRpcCall}
+	 * @param callback callback that receives the serialized response
+	 * @throws DMLRuntimeException if the call names a method other than push or pull
+	 */
+	@Override
+	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
+		PSRpcCall call = new PSRpcCall(buffer);
+		PSRpcResponse response;
+		switch (call.getMethod()) {
+			case PUSH:
+				try {
+					_server.push(call.getWorkerID(), call.getData());
+					response = new PSRpcResponse(SUCCESS, EMPTY_DATA);
+				} catch (RuntimeException exception) {
+					response = createErrorResponse(exception);
+				}
+				break;
+			case PULL:
+				try {
+					ListObject data = _server.pull(call.getWorkerID());
+					response = new PSRpcResponse(SUCCESS, data);
+				} catch (RuntimeException exception) {
+					response = createErrorResponse(exception);
+				}
+				break;
+			default:
+				throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
+		}
+		// Reached only with a response assigned; sending it from a finally block
+		// would dereference null whenever an uncaught exception skipped the assignment.
+		callback.onSuccess(response.serialize());
+	}
+
+	// Wraps the stack trace of a failed call into an ERROR response
+	private static PSRpcResponse createErrorResponse(Throwable cause) {
+		return new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(cause));
+	}
+
+	@Override
+	public StreamManager getStreamManager() {
+		return new OneForOneStreamManager();
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
new file mode 100644
index 0000000..c6d7fd3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Base class of the parameter-server rpc messages; holds the shared
+ * constants and the helpers used by the textual (de)serialization.
+ */
+public abstract class PSRpcObject {
+
+	public static final String PUSH = "push";
+	public static final String PULL = "pull";
+	public static final String DATA_KEY = "data";
+	public static final String EMPTY_DATA = "";
+
+	public abstract void deserialize(ByteBuffer buffer);
+
+	public abstract ByteBuffer serialize();
+
+	/**
+	 * Convert direct byte buffer to string
+	 * <p>
+	 * Only the bytes between the current position and the limit are read, so a
+	 * buffer whose position is not zero does not cause a BufferUnderflowException.
+	 * The bytes are decoded with the platform default charset.
+	 *
+	 * @param buffer direct byte buffer
+	 * @return string
+	 */
+	protected String bufferToString(ByteBuffer buffer) {
+		byte[] result = new byte[buffer.remaining()];
+		buffer.get(result);
+		return new String(result);
+	}
+
+	/**
+	 * Flush the data into HDFS
+	 * @param data list object
+	 */
+	protected void flushListObject(ListObject data) {
+		// only the cacheable entries of the list are exported
+		data.getData().stream().filter(d -> d instanceof CacheableData)
+			.forEach(d -> ((CacheableData<?>) d).exportData());
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
new file mode 100644
index 0000000..998c523
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+/**
+ * Response message of the parameter-server rpc: a status code plus either
+ * the resulting list object (SUCCESS) or an error message (ERROR).
+ */
+public class PSRpcResponse extends PSRpcObject {
+
+	public static final int SUCCESS = 1;
+	public static final int ERROR = 2;
+
+	private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN;
+	private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END;
+
+	private int _status;
+	private Object _data;	// Could be list object or exception
+
+	/**
+	 * Rebuilds a response from the bytes produced by {@link #serialize()}.
+	 *
+	 * @param buffer serialized response
+	 */
+	public PSRpcResponse(ByteBuffer buffer) {
+		deserialize(buffer);
+	}
+
+	/**
+	 * @param status SUCCESS or ERROR
+	 * @param data   list object or EMPTY_DATA for SUCCESS, error message for ERROR
+	 */
+	public PSRpcResponse(int status, Object data) {
+		_status = status;
+		_data = data;
+	}
+
+	public boolean isSuccessful() {
+		return _status == SUCCESS;
+	}
+
+	public String getErrorMessage() {
+		return (String) _data;
+	}
+
+	public ListObject getResultModel() {
+		return (ListObject) _data;
+	}
+
+	/**
+	 * Parses the textual form written by {@link #serialize()}: the begin/end
+	 * markers are stripped and the remainder is split at the first component
+	 * delimiter into status and data, so an error message that itself contains
+	 * the delimiter is kept in full.
+	 *
+	 * @param buffer serialized response
+	 * @throws DMLRuntimeException if no delimiter follows the status
+	 */
+	@Override
+	public void deserialize(ByteBuffer buffer) {
+		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+		String input = bufferToString(buffer);
+		//header elimination
+		input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end
+		int split = input.indexOf(COMPONENTS_DELIM);
+		if (split < 0)
+			throw new DMLRuntimeException("Malformed rpc response: missing delimiter after the status.");
+
+		_status = Integer.parseInt(input.substring(0, split));
+		String data = input.substring(split + COMPONENTS_DELIM.length());
+		switch (_status) {
+			case SUCCESS:
+				_data = data.equals(EMPTY) ? null :
+					ProgramConverter.parseDataObject(data)[1];
+				break;
+			case ERROR:
+				_data = data;
+				break;
+		}
+	}
+
+	/**
+	 * Writes the response as begin marker, status, component delimiter and
+	 * data, closed by the end marker.
+	 *
+	 * @return buffer wrapping the bytes of the serialized response
+	 */
+	@Override
+	public ByteBuffer serialize() {
+		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+		StringBuilder sb = new StringBuilder();
+		sb.append(PS_RPC_RESPONSE_BEGIN);
+		sb.append(_status);
+		sb.append(COMPONENTS_DELIM);
+		switch (_status) {
+			case SUCCESS:
+				// a null result is written like EMPTY_DATA instead of raising a NullPointerException
+				if (_data == null || EMPTY_DATA.equals(_data)) {
+					sb.append(EMPTY);
+				} else {
+					flushListObject((ListObject) _data);
+					sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data));
+				}
+				break;
+			case ERROR:
+				// String.valueOf tolerates a null message
+				sb.append(String.valueOf(_data));
+				break;
+		}
+		sb.append(PS_RPC_RESPONSE_END);
+		return ByteBuffer.wrap(sb.toString().getBytes());
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 4e7a718..6133987 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -55,6 +55,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
+import org.apache.spark.network.server.TransportServer;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.recompile.Recompiler;
 import org.apache.sysml.lops.LopProperties;
@@ -71,6 +72,7 @@ import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -114,16 +116,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	}
 
 	private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
+		Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+
 		PSScheme scheme = getScheme();
 		int workerNum = getWorkerNum(mode);
 		String updFunc = getParam(PS_UPDATE_FUN);
 		String aggFunc = getParam(PS_AGGREGATION_FUN);
 
-		int k = getParLevel(workerNum);
-
 		// Get the compiled execution context
 		LocalVariableMap newVarsMap = createVarsMap(sec);
-		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, k);
+		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend
 
 		MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
 		MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
@@ -131,16 +133,47 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		// Force all the instructions to CP type
 		Recompiler.recompileProgramBlockHierarchy2Forced(
 			newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);
-		
+
 		// Serialize all the needed params for remote workers
 		SparkPSBody body = new SparkPSBody(newEC);
 		HashMap<String, byte[]> clsMap = new HashMap<>();
 		String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
 
-		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap);
-		ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
-			.foreach(worker);   // Run remote workers
+		// Get some configurations
+		String host = sec.getSparkContext().getConf().get("spark.driver.host");
+		long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ? 
+			sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") :
+			sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s");
+
+		// Create remote workers
+		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(),
+			getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout);
+
+		// Create the agg service's execution context
+		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+
+		// Create the parameter server
+		ListObject model = sec.getListObject(getParam(PS_MODEL));
+		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+
+		if (DMLScript.STATISTICS)
+			Statistics.accPSSetupTime((long) tSetup.stop());
+
+		// Create the netty server for ps
+		TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server
 
+		try {
+			ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
+				.foreach(worker); // Run remote workers
+		} catch (Exception e) {
+			throw new DMLRuntimeException("Paramserv function failed: ", e);
+		} finally {
+			// Stop the netty server
+			server.close();
+		}
+
+		// Fetch the final model from ps
+		sec.setVariable(output.getName(), ps.getResult());
 	}
 
 	private void runLocally(ExecutionContext ec, PSModeType mode) {
@@ -176,8 +209,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
 		MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
 		List<LocalPSWorker> workers = IntStream.range(0, workerNum)
-		   .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
-		   .collect(Collectors.toList());
+			.mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
+			.collect(Collectors.toList());
 
 		// Do data partition
 		PSScheme scheme = getScheme();
@@ -296,6 +329,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 	private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
 		switch (mode) {
 			case LOCAL:
+			case REMOTE_SPARK:
 				return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
 			default:
 				throw new DMLRuntimeException("Unsupported parameter server: "+mode.name());

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
index 1d2115e..fc9d9b4 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
@@ -143,7 +143,7 @@ public class ProgramConverter
 	public static final String PB_IF = " IF" + LEVELIN;
 	public static final String PB_FC = " FC" + LEVELIN;
 	public static final String PB_EFC = " EFC" + LEVELIN;
-	
+
 	public static final String CONF_STATS = "stats";
 
 	// Used for parfor
@@ -716,9 +716,10 @@ public class ProgramConverter
 		builder.append(rSerializeProgramBlocks(ec.getProgram().getProgramBlocks(), clsMap));
 		builder.append(PBS_END);
 		builder.append(NEWLINE);
+		builder.append(COMPONENTS_DELIM);
+		builder.append(NEWLINE);
 
 		builder.append(PSBODY_END);
-
 		return builder.toString();
 	}
 
@@ -868,7 +869,7 @@ public class ProgramConverter
 				value = mo.getFileName();
 				PartitionFormat partFormat = (mo.getPartitionFormat()!=null) ? new PartitionFormat(
 						mo.getPartitionFormat(),mo.getPartitionSize()) : PartitionFormat.NONE;
-				metaData = new String[9];
+				metaData = new String[11];
 				metaData[0] = String.valueOf( mc.getRows() );
 				metaData[1] = String.valueOf( mc.getCols() );
 				metaData[2] = String.valueOf( mc.getRowsPerBlock() );
@@ -878,6 +879,8 @@ public class ProgramConverter
 				metaData[6] = OutputInfo.outputInfoToString( md.getOutputInfo() );
 				metaData[7] = String.valueOf( partFormat );
 				metaData[8] = String.valueOf( mo.getUpdateType() );
+				metaData[9] = String.valueOf(mo.isHDFSFileExists());
+				metaData[10] = String.valueOf(mo.isCleanupEnabled());
 				break;
 			case LIST:
 				// SCHEMA: <name>|<datatype>|<valuetype>|value|<metadata>|<tab>element1<tab>element2<tab>element3 (this is the list)
@@ -1683,6 +1686,8 @@ public class ProgramConverter
 				if( partFormat._dpf != PDataPartitionFormat.NONE )
 					mo.setPartitioned( partFormat._dpf, partFormat._N );
 				mo.setUpdateType(inplace);
+				mo.setHDFSFileExists(Boolean.valueOf(st.nextToken()));
+				mo.enableCleanup(Boolean.valueOf(st.nextToken()));
 				dat = mo;
 				break;
 			}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 8f0d853..1dd8362 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -125,6 +125,7 @@ public class Statistics
 	private static final LongAdder psLocalModelUpdateTime = new LongAdder();
 	private static final LongAdder psModelBroadcastTime = new LongAdder();
 	private static final LongAdder psBatchIndexTime = new LongAdder();
+	private static final LongAdder psRpcRequestTime = new LongAdder();
 
 	//PARFOR optimization stats (low frequency updates)
 	private static long parforOptTime = 0; //in milli sec
@@ -564,6 +565,10 @@ public class Statistics
 		psBatchIndexTime.add(t);
 	}
 
+	public static void accPSRpcRequestTime(long t) {
+		psRpcRequestTime.add(t);
+	}
+
 	public static String getCPHeavyHitterCode( Instruction inst )
 	{
 		String opcode = null;
@@ -1003,6 +1008,7 @@ public class Statistics
 						psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 1000));
 				sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
 				sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+				sb.append(String.format("Paramserv RPC request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
 			}
 			if( parforOptCount>0 ){
 				sb.append("ParFor loops optimized:\t\t" + getParforOptCount() + ".\n");

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
index d5fd509..905bfd1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
@@ -19,75 +19,66 @@
 
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import org.apache.sysml.parser.Statement;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
 
 public class ParamservLocalNNTest extends AutomatedTestBase {
 
-	private static final String TEST_NAME1 = "paramserv-nn-bsp-batch-dc";
-	private static final String TEST_NAME2 = "paramserv-nn-asp-batch";
-	private static final String TEST_NAME3 = "paramserv-nn-bsp-epoch";
-	private static final String TEST_NAME4 = "paramserv-nn-asp-epoch";
-	private static final String TEST_NAME5 = "paramserv-nn-bsp-batch-drr";
-	private static final String TEST_NAME6 = "paramserv-nn-bsp-batch-dr";
-	private static final String TEST_NAME7 = "paramserv-nn-bsp-batch-or";
+	private static final String TEST_NAME = "paramserv-test";
 
 	private static final String TEST_DIR = "functions/paramserv/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTest.class.getSimpleName() + "/";
 
 	@Override
 	public void setUp() {
-		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
-		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
-		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
-		addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {}));
-		addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
-		addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
-		addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointContiguous() {
-		runDMLTest(TEST_NAME1);
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservASPBatch() {
-		runDMLTest(TEST_NAME2);
+		runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservBSPEpoch() {
-		runDMLTest(TEST_NAME3);
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservASPEpoch() {
-		runDMLTest(TEST_NAME4);
+		runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointRoundRobin() {
-		runDMLTest(TEST_NAME5);
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointRandom() {
-		runDMLTest(TEST_NAME6);
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
 	}
 
 	@Test
 	public void testParamservBSPBatchOverlapReshuffle() {
-		runDMLTest(TEST_NAME7);
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
 	}
 
-	private void runDMLTest(String testname) {
-		TestConfiguration config = getTestConfiguration(testname);
+	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+		TestConfiguration config = getTestConfiguration(ParamservLocalNNTest.TEST_NAME);
 		loadTestConfiguration(config);
-		programArgs = new String[] { "-explain" };
+		programArgs = new String[] { "-explain", "-nvargs", "mode=LOCAL", "epochs=" + epochs,
+			"workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize,
+			"scheme=" + scheme };
 		String HOME = SCRIPT_DIR + TEST_DIR;
-		fullDMLScriptName = HOME + testname + ".dml";
+		fullDMLScriptName = HOME + ParamservLocalNNTest.TEST_NAME + ".dml";
 		runTest(true, false, null, null, -1);
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
index 2441116..30eccb3 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
@@ -1,14 +1,24 @@
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
 
 public class ParamservSparkNNTest extends AutomatedTestBase {
 
-	private static final String TEST_NAME1 = "paramserv-spark-nn-bsp-batch-dc";
+	private static final String TEST_NAME1 = "paramserv-test";
+	private static final String TEST_NAME2 = "paramserv-spark-worker-failed";
+	private static final String TEST_NAME3 = "paramserv-spark-agg-service-failed";
 
 	private static final String TEST_DIR = "functions/paramserv/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSparkNNTest.class.getSimpleName() + "/";
@@ -16,14 +26,42 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
 	}
 
 	@Test
 	public void testParamservBSPBatchDisjointContiguous() {
-		runDMLTest(TEST_NAME1);
+		runDMLTest(2, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
+	@Test
+	public void testParamservASPBatchDisjointContiguous() {
+		runDMLTest(2, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
+	@Test
+	public void testParamservBSPEpochDisjointContiguous() {
+		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+	}
+
+	@Test
+	public void testParamservASPEpochDisjointContiguous() {
+		runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
-	private void runDMLTest(String testname) {
+	@Test
+	public void testParamservWorkerFailed() {
+		runDMLTest(TEST_NAME2, true, DMLException.class, "Invalid indexing by name in unnamed list: worker_err.");
+	}
+
+	@Test
+	public void testParamservAggServiceFailed() {
+		runDMLTest(TEST_NAME3, true, DMLException.class, "Invalid indexing by name in unnamed list: agg_service_err.");
+	}
+
+	private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) {
+		programArgs = new String[] { "-explain" };
 		DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform;
 		boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
 		AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK;
@@ -32,16 +70,32 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 		try {
 			TestConfiguration config = getTestConfiguration(testname);
 			loadTestConfiguration(config);
-			programArgs = new String[] { "-explain" };
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			// The test is not already finished, so it is normal to have the NPE
-			runTest(true, true, DMLException.class, null, -1);
+			runTest(true, exceptionExpected, expectedException, errMessage, -1);
 		} finally {
 			AutomatedTestBase.rtplatform = oldRtplatform;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldUseLocalSparkConfig;
 		}
-
 	}
 
+	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+		Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString())
+			.in("$epochs", String.valueOf(epochs))
+			.in("$workers", String.valueOf(workers))
+			.in("$utype", utype.toString())
+			.in("$freq", freq.toString())
+			.in("$batchsize", String.valueOf(batchsize))
+			.in("$scheme", scheme.toString());
+
+		SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]")
+			.set("spark.driver.allowMultipleContexts", "true");
+		JavaSparkContext sc = new JavaSparkContext(conf);
+		MLContext ml = new MLContext(sc);
+		ml.setStatistics(true);
+		ml.execute(script);
+		ml.resetConfig();
+		sc.stop();
+		ml.close();
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
new file mode 100644
index 0000000..57e1106
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.paramserv;
+
+import java.util.Arrays;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RpcObjectTest {
+
+	@Test
+	public void testPSRpcCall() {
+		MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+		MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+		IntObject io = new IntObject(30);
+		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+		PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
+		PSRpcCall actual = new PSRpcCall(expected.serialize());
+		Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+	}
+
+	@Test
+	public void testPSRpcResponse() {
+		MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+		MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+		IntObject io = new IntObject(30);
+		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+		PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo);
+		PSRpcResponse actual = new PSRpcResponse(expected.serialize());
+		Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
index 2a08ca6..64d6492 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
@@ -68,7 +68,7 @@ public class SerializationTest {
 		Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
 	}
 
-	private MatrixObject generateDummyMatrix(int size) {
+	public static MatrixObject generateDummyMatrix(int size) {
 		double[] dl = new double[size];
 		for (int i = 0; i < size; i++) {
 			dl[i] = i;

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
deleted file mode 100644
index ba22942..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
deleted file mode 100644
index c8c6a2f..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
deleted file mode 100644
index 78fc1c4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
deleted file mode 100644
index 9191b5a..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
deleted file mode 100644
index ec18cb4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the dummy dataset (images, labels) plus C, Hin, Win via mnist_lenet::generate_dummy_data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# NOTE(review): this second generate_dummy_data call looks redundant - X and Y are overwritten by the split below; confirm before removing
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation: rows 1..val_size (n * 0.1) form the validation set, the remaining rows the training set
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments handed to train (epochs, workers, batchsize) and predict (batchsize) below
-epochs = 10
-workers = 4
-batchsize = 32
-
-# Train; the trailing string arguments are presumably update type, update frequency, data partitioning scheme and execution mode - confirm against mnist_lenet::train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN", "LOCAL")
-
-# Compute validation loss & accuracy: predict on X_val, cross-entropy loss against Y_val, accuracy = fraction of rows whose rowIndexMax matches
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results: print the validation loss and accuracy on one line
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
deleted file mode 100644
index 928dde2..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the dummy dataset (images, labels) plus C, Hin, Win via mnist_lenet::generate_dummy_data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# NOTE(review): this second generate_dummy_data call looks redundant - X and Y are overwritten by the split below; confirm before removing
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation: rows 1..val_size (n * 0.1) form the validation set, the remaining rows the training set
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments handed to train (epochs, workers, batchsize) and predict (batchsize) below
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train; the trailing string arguments are presumably update type, update frequency, data partitioning scheme and execution mode - confirm against mnist_lenet::train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE", "LOCAL")
-
-# Compute validation loss & accuracy: predict on X_val, cross-entropy loss against Y_val, accuracy = fraction of rows whose rowIndexMax matches
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results: print the validation loss and accuracy on one line
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file