Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/27 06:32:50 UTC

systemml git commit: [SYSTEMML-2420, 2457] Improved distributed paramserv comm and stats

Repository: systemml
Updated Branches:
  refs/heads/master bfd495289 -> e0c271fe4


[SYSTEMML-2420,2457] Improved distributed paramserv comm and stats

Closes #808.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e0c271fe
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e0c271fe
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e0c271fe

Branch: refs/heads/master
Commit: e0c271fe4cb05d3a53eb0143ab298e26831d1ed7
Parents: bfd4952
Author: EdgarLGB <gu...@atos.net>
Authored: Thu Jul 26 23:33:12 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Jul 26 23:33:13 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/caching/CacheableData.java   |  6 ++
 .../controlprogram/paramserv/LocalPSWorker.java | 48 +++++++---
 .../controlprogram/paramserv/PSWorker.java      | 24 ++++-
 .../paramserv/ParamservUtils.java               |  5 ++
 .../paramserv/spark/SparkPSProxy.java           | 35 ++++++--
 .../paramserv/spark/SparkPSWorker.java          | 93 ++++++++++++++++---
 .../paramserv/spark/rpc/PSRpcCall.java          | 92 +++++++++----------
 .../paramserv/spark/rpc/PSRpcFactory.java       | 24 ++---
 .../paramserv/spark/rpc/PSRpcHandler.java       | 32 ++++---
 .../paramserv/spark/rpc/PSRpcObject.java        | 85 ++++++++++++++----
 .../paramserv/spark/rpc/PSRpcResponse.java      | 94 ++++++++++----------
 .../cp/ParamservBuiltinCPInstruction.java       | 56 ++++++++----
 .../sysml/runtime/io/IOUtilFunctions.java       | 10 +++
 .../java/org/apache/sysml/utils/Statistics.java |  4 +
 .../paramserv/ParamservSparkNNTest.java         | 37 +++-----
 .../functions/paramserv/RpcObjectTest.java      | 22 ++---
 .../paramserv-spark-agg-service-failed.dml      |  6 +-
 .../paramserv/paramserv-spark-worker-failed.dml |  6 +-
 18 files changed, 441 insertions(+), 238 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index f524251..0265c33 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -364,6 +364,12 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 	// ***                                       ***
 	// *********************************************
 
+	public T acquireReadAndRelease() {
+		T tmp = acquireRead();
+		release();
+		return tmp;
+	}
+	
 	/**
 	 * Acquires a shared "read-only" lock, produces the reference to the cache block,
 	 * restores the cache block to main memory, reads from HDFS if needed.

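The new CacheableData.acquireReadAndRelease() collapses the common acquireRead()/release() pair into a single call. A minimal usage sketch (hypothetical caller; MatrixObject is the concrete CacheableData used throughout the paramserv code):

  import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
  import org.apache.sysml.runtime.matrix.data.MatrixBlock;

  static MatrixBlock readBlock(MatrixObject mo) {
    // before: MatrixBlock mb = mo.acquireRead(); mo.release(); return mb;
    // after: one call pins the block, returns it, and drops the shared read lock
    return mo.acquireReadAndRelease();
  }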
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index c23943d..b8a416f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -48,12 +48,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 	public String getWorkerName() {
 		return String.format("Local worker_%d", _workerID);
 	}
-
+	
 	@Override
 	public Void call() throws Exception {
-		if (DMLScript.STATISTICS)
-			Statistics.incWorkerNumber();
-		
+		incWorkerNumber();
 		try {
 			long dataSize = _features.getNumRows();
 			int totalIter = (int) Math.ceil((double) dataSize / _batchSize);
@@ -94,17 +92,19 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 				if( j < totalIter - 1 )
 					params = updateModel(params, gradients, i, j, totalIter);
 				ParamservUtils.cleanupListObject(_ec, gradients);
+				
+				accNumBatches(1);
 			}
 
 			// Push the gradients to ps
 			pushGradients(accGradients);
 			ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 
+			accNumEpochs(1);
 			if (LOG.isDebugEnabled()) {
 				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
 			}
 		}
-
 	}
 
 	private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) {
@@ -112,8 +112,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
 		globalParams = _ps.updateLocalModel(_ec, gradients, globalParams);
 
-		if (DMLScript.STATISTICS)
-			Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());
+		accLocalModelUpdateTime(tUpd);
 		
 		if (LOG.isDebugEnabled()) {
 			LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
@@ -133,9 +132,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
 				// Push the gradients to ps
 				pushGradients(gradients);
-
 				ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
+				
+				accNumBatches(1);
 			}
+			
+			accNumEpochs(1);
 			if (LOG.isDebugEnabled()) {
 				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
 			}
@@ -169,8 +171,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
 		MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
 		MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
-		if (DMLScript.STATISTICS)
-			Statistics.accPSBatchIndexingTime((long) tSlic.stop());
+		accBatchIndexingTime(tSlic);
 
 		_ec.setVariable(Statement.PS_FEATURES, bFeatures);
 		_ec.setVariable(Statement.PS_LABELS, bLabels);
@@ -185,8 +186,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		// Invoke the update function
 		Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
 		_inst.processInstruction(_ec);
-		if (DMLScript.STATISTICS)
-			Statistics.accPSGradientComputeTime((long) tGrad.stop());
+		accGradientComputeTime(tGrad);
 
 		// Get the gradients
 		ListObject gradients = (ListObject) _ec.getVariable(_output.getName());
@@ -195,4 +195,28 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		ParamservUtils.cleanupData(_ec, bLabels);
 		return gradients;
 	}
+	
+	@Override
+	protected void incWorkerNumber() {
+		if (DMLScript.STATISTICS)
+			Statistics.incWorkerNumber();
+	}
+
+	@Override
+	protected void accLocalModelUpdateTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			Statistics.accPSLocalModelUpdateTime((long) time.stop());
+	}
+
+	@Override
+	protected void accBatchIndexingTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			Statistics.accPSBatchIndexingTime((long) time.stop());
+	}
+
+	@Override
+	protected void accGradientComputeTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			Statistics.accPSGradientComputeTime((long) time.stop());
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 4b5c5c1..5f2d552 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -32,11 +32,12 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 
-public abstract class PSWorker implements Serializable {
-
+public abstract class PSWorker implements Serializable 
+{
 	private static final long serialVersionUID = -3510485051178200118L;
 
 	protected int _workerID;
@@ -133,4 +134,23 @@ public abstract class PSWorker implements Serializable {
 	}
 
 	public abstract String getWorkerName();
+
+	/**
+	 * ----- The following methods are dedicated to statistics -------------
+ 	 */
+	protected abstract void incWorkerNumber();
+
+	protected abstract void accLocalModelUpdateTime(Timing time);
+
+	protected abstract void accBatchIndexingTime(Timing time);
+
+	protected abstract void accGradientComputeTime(Timing time);
+
+	protected void accNumEpochs(int n) {
+		//do nothing
+	}
+	
+	protected void accNumBatches(int n) {
+		//do nothing
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index cf27457..9624c55 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -156,11 +156,16 @@ public class ParamservUtils {
 	}
 
 	public static MatrixObject newMatrixObject(MatrixBlock mb) {
+		return newMatrixObject(mb, true);
+	}
+	
+	public static MatrixObject newMatrixObject(MatrixBlock mb, boolean cleanup) {
 		MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(),
 			new MetaDataFormat(new MatrixCharacteristics(-1, -1, ConfigurationManager.getBlocksize(),
 			ConfigurationManager.getBlocksize()), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
 		result.acquireModify(mb);
 		result.release();
+		result.enableCleanup(cleanup);
 		return result;
 	}
 

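The new cleanup flag on newMatrixObject matters on the RPC receive path: PSRpcObject.readAndDeserialize (below) wraps incoming matrix blocks with cleanup disabled, presumably so the deserialized model and gradients are not destroyed by generic variable cleanup while still in use. A short sketch of the two variants (dimensions are arbitrary):

  import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
  import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
  import org.apache.sysml.runtime.matrix.data.MatrixBlock;

  MatrixBlock mb = new MatrixBlock(10, 10, false); // 10x10, dense
  MatrixObject tmp  = ParamservUtils.newMatrixObject(mb);        // cleanup enabled (default)
  MatrixObject kept = ParamservUtils.newMatrixObject(mb, false); // exempt from cleanup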
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
index de7b6c6..48a4883 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
@@ -22,7 +22,10 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
 
+import java.io.IOException;
+
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
@@ -30,25 +33,35 @@ import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.utils.Statistics;
 
 public class SparkPSProxy extends ParamServer {
 
-	private TransportClient _client;
+	private final TransportClient _client;
 	private final long _rpcTimeout;
+	private final LongAccumulator _aRPC;
 
-	public SparkPSProxy(TransportClient client, long rpcTimeout) {
+	public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
 		super();
 		_client = client;
 		_rpcTimeout = rpcTimeout;
+		_aRPC = aRPC;
+	}
+
+	private void accRpcRequestTime(Timing tRpc) {
+		if (DMLScript.STATISTICS)
+			_aRPC.add((long) tRpc.stop());
 	}
 
 	@Override
 	public void push(int workerID, ListObject value) {
 		Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
-		PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
-		if (DMLScript.STATISTICS)
-			Statistics.accPSRpcRequestTime((long) tRpc.stop());
+		PSRpcResponse response;
+		try {
+			response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
+		} catch (IOException e) {
+			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
+		}
+		accRpcRequestTime(tRpc);
 		if (!response.isSuccessful()) {
 			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
 		}
@@ -57,9 +70,13 @@ public class SparkPSProxy extends ParamServer {
 	@Override
 	public ListObject pull(int workerID) {
 		Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
-		PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
-		if (DMLScript.STATISTICS)
-			Statistics.accPSRpcRequestTime((long) tRpc.stop());
+		PSRpcResponse response;
+		try {
+			response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
+		} catch (IOException e) {
+			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
+		}
+		accRpcRequestTime(tRpc);
 		if (!response.isSuccessful()) {
 			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index fa06243..59203ad 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -23,7 +23,9 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
@@ -34,7 +36,6 @@ import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.ProgramConverter;
-import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
@@ -42,13 +43,21 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 
 	private static final long serialVersionUID = -8674739573419648732L;
 
-	private String _program;
-	private HashMap<String, byte[]> _clsMap;
-	private String _host; // host ip of driver
-	private long _rpcTimeout; // rpc ask timeout
-	private String _aggFunc;
-
-	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) {
+	private final String _program;
+	private final HashMap<String, byte[]> _clsMap;
+	private final SparkConf _conf;
+	private final int _port; // rpc port
+	private final String _aggFunc;
+	private final LongAccumulator _aSetup; // accumulator for setup time
+	private final LongAccumulator _aWorker; // accumulator for worker number
+	private final LongAccumulator _aUpdate; // accumulator for model update
+	private final LongAccumulator _aIndex; // accumulator for batch indexing
+	private final LongAccumulator _aGrad; // accumulator for gradients computing
+	private final LongAccumulator _aRPC; // accumulator for rpc request
+	private final LongAccumulator _nBatches; //number of executed batches
+	private final LongAccumulator _nEpochs; //number of executed epochs
+	
+	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
 		_updFunc = updFunc;
 		_aggFunc = aggFunc;
 		_freq = freq;
@@ -56,21 +65,29 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		_batchSize = batchSize;
 		_program = program;
 		_clsMap = clsMap;
-		_host = host;
-		_rpcTimeout = rpcTimeout;
+		_conf = conf;
+		_port = port;
+		_aSetup = aSetup;
+		_aWorker = aWorker;
+		_aUpdate = aUpdate;
+		_aIndex = aIndex;
+		_aGrad = aGrad;
+		_aRPC = aRPC;
+		_nBatches = aBatches;
+		_nEpochs = aEpochs;
 	}
 
 	@Override
 	public String getWorkerName() {
 		return String.format("Spark worker_%d", _workerID);
 	}
-
+	
 	@Override
 	public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
 		Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
 		configureWorker(input);
-		if (DMLScript.STATISTICS)
-			Statistics.accPSSetupTime((long) tSetup.stop());
+		accSetupTime(tSetup);
+
 		call(); // Launch the worker
 	}
 
@@ -89,8 +106,14 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		// Initialize the buffer pool and register it in the jvm shutdown hook so it is cleaned up at the end
 		RemoteParForUtils.setupBufferPool(_workerID);
 
+		// Get some configurations
+		long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ?
+			_conf.getTimeAsMs("spark.rpc.askTimeout") :
+			_conf.getTimeAsMs("spark.network.timeout", "120s");
+		String host = _conf.get("spark.driver.host");
+
 		// Create the ps proxy
-		_ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout);
+		_ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC);
 
 		// Initialize the update function
 		setupUpdateFunction(_updFunc, _ec);
@@ -104,4 +127,46 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		_features.enableCleanup(false);
 		_labels.enableCleanup(false);
 	}
+	
+
+	@Override
+	public void incWorkerNumber() {
+		if (DMLScript.STATISTICS)
+			_aWorker.add(1);
+	}
+	
+	@Override
+	public void accLocalModelUpdateTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			_aUpdate.add((long) time.stop());
+	}
+
+	@Override
+	public void accBatchIndexingTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			_aIndex.add((long) time.stop());
+	}
+
+	@Override
+	public void accGradientComputeTime(Timing time) {
+		if (DMLScript.STATISTICS)
+			_aGrad.add((long) time.stop());
+	}
+	
+	@Override
+	protected void accNumEpochs(int n) {
+		if (DMLScript.STATISTICS)
+			_nEpochs.add(n);
+	}
+	
+	@Override
+	protected void accNumBatches(int n) {
+		if (DMLScript.STATISTICS)
+			_nBatches.add(n);
+	}
+	
+	private void accSetupTime(Timing tSetup) {
+		if (DMLScript.STATISTICS)
+			_aSetup.add((long) tSetup.stop());
+	}
 }

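Worker-side statistics now travel through Spark LongAccumulators instead of the JVM-local Statistics singleton, because the workers run in executor JVMs. Accumulators are captured in the task closure, updated locally on the executors, and merged on the driver. A standalone sketch of that pattern (local-mode Spark assumed, not SystemML-specific):

  import java.util.Arrays;
  import org.apache.spark.api.java.JavaSparkContext;
  import org.apache.spark.sql.SparkSession;
  import org.apache.spark.util.LongAccumulator;

  public class AccumulatorSketch {
    public static void main(String[] args) {
      SparkSession spark = SparkSession.builder()
        .master("local[*]").appName("acc-sketch").getOrCreate();
      JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
      LongAccumulator nBatches = jsc.sc().longAccumulator("numBatches");
      jsc.parallelize(Arrays.asList(1, 2, 3, 4))
        .foreach(x -> nBatches.add(1));     // executor-side local adds
      System.out.println(nBatches.sum());   // driver-side merged total: 4
      spark.stop();
    }
  }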
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
index 999d409..b8f482c 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
@@ -19,71 +19,34 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
-import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
-import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
-
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.StringTokenizer;
 
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.util.FastBufferedDataOutputStream;
 
 public class PSRpcCall extends PSRpcObject {
 
-	private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN;
-	private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END;
-
-	private String _method;
+	private int _method;
 	private int _workerID;
 	private ListObject _data;
 
-	public PSRpcCall(String method, int workerID, ListObject data) {
+	public PSRpcCall(int method, int workerID, ListObject data) {
 		_method = method;
 		_workerID = workerID;
 		_data = data;
 	}
 
-	public PSRpcCall(ByteBuffer buffer) {
+	public PSRpcCall(ByteBuffer buffer) throws IOException {
 		deserialize(buffer);
 	}
 
-	public void deserialize(ByteBuffer buffer) {
-		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
-		String input = bufferToString(buffer);
-		//header elimination
-		input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end
-		StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
-
-		_method = st.nextToken();
-		_workerID = Integer.valueOf(st.nextToken());
-		String dataStr = st.nextToken();
-		_data = dataStr.equals(EMPTY) ? null :
-			(ListObject) ProgramConverter.parseDataObject(dataStr)[1];
-	}
-
-	public ByteBuffer serialize() {
-		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
-		StringBuilder sb = new StringBuilder();
-		sb.append(PS_RPC_CALL_BEGIN);
-		sb.append(_method);
-		sb.append(COMPONENTS_DELIM);
-		sb.append(_workerID);
-		sb.append(COMPONENTS_DELIM);
-		if (_data == null) {
-			sb.append(EMPTY);
-		} else {
-			flushListObject(_data);
-			sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data));
-		}
-		sb.append(PS_RPC_CALL_END);
-		return ByteBuffer.wrap(sb.toString().getBytes());
-	}
-
-	public String getMethod() {
+	public int getMethod() {
 		return _method;
 	}
 
@@ -94,4 +57,37 @@ public class PSRpcCall extends PSRpcObject {
 	public ListObject getData() {
 		return _data;
 	}
+	
+	public void deserialize(ByteBuffer buffer) throws IOException {
+		DataInputStream dis = new DataInputStream(
+			new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer)));
+		_method = dis.readInt();
+		validateMethod(_method);
+		_workerID = dis.readInt();
+		if (dis.available() > 1)
+			_data = readAndDeserialize(dis);
+		dis.close();
+	}
+
+	public ByteBuffer serialize() throws IOException {
+		//TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling)
+		ByteArrayOutputStream bos = new ByteArrayOutputStream(getApproxSerializedSize(_data));
+		FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos);
+		dos.writeInt(_method);
+		dos.writeInt(_workerID);
+		if (_data != null)
+			serializeAndWriteListObject(_data, dos);
+		dos.flush();
+		return ByteBuffer.wrap(bos.toByteArray());
+	}
+	
+	private void validateMethod(int method) {
+		switch (method) {
+			case PUSH:
+			case PULL:
+				break;
+			default:
+				throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'");
+		}
+	}
 }

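PSRpcCall thus replaces the old string-based encoding with a compact binary layout: int method | int workerID | optional list payload. A minimal round trip of just the header in plain java.io (inside some main() throws IOException; the payload and SystemML types are elided):

  import java.io.*;
  import java.nio.ByteBuffer;

  ByteArrayOutputStream bos = new ByteArrayOutputStream();
  DataOutputStream dos = new DataOutputStream(bos);
  dos.writeInt(1); // method (1 = PUSH, 2 = PULL)
  dos.writeInt(7); // workerID
  dos.flush();
  ByteBuffer buf = ByteBuffer.wrap(bos.toByteArray());

  DataInputStream dis = new DataInputStream(new ByteArrayInputStream(buf.array()));
  int method   = dis.readInt();          // 1
  int workerID = dis.readInt();          // 7
  boolean hasData = dis.available() > 1; // same payload test as deserialize()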
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
index c8b4024..2d921de 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
@@ -22,36 +22,36 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 import java.io.IOException;
 import java.util.Collections;
 
+import org.apache.spark.SparkConf;
 import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
 import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
 
-//TODO should be able to configure the port by users
 public class PSRpcFactory {
 
-	private static final int PORT = 5055;
 	private static final String MODULE_NAME = "ps";
 
-	private static TransportContext createTransportContext(LocalParamServer ps) {
-		TransportConf conf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider());
+	private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
+		TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
 		PSRpcHandler handler = new PSRpcHandler(ps);
-		return new TransportContext(conf, handler);
+		return new TransportContext(tc, handler);
 	}
 
 	/**
 	 * Create and start the server
 	 * @return server
 	 */
-	public static TransportServer createServer(LocalParamServer ps, String host) {
-		TransportContext context = createTransportContext(ps);
-		return context.createServer(host, PORT, Collections.emptyList());
+	public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
+		TransportContext context = createTransportContext(conf, ps);
+		return context.createServer(host, 0, Collections.emptyList());	// bind rpc to an ephemeral port
 	}
 
-	public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException {
-		TransportContext context = createTransportContext(new LocalParamServer());
-		return new SparkPSProxy(context.createClientFactory().createClient(host, PORT), rpcTimeout);
+	public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException {
+		TransportContext context = createTransportContext(conf, new LocalParamServer());
+		return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC);
 	}
 }

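With the hard-coded port 5055 gone, createServer binds to port 0 and the OS assigns a free ephemeral port; the driver reads the actual port back (server.getPort() in ParamservBuiltinCPInstruction below) and ships it to the workers. The same idea with a plain ServerSocket:

  import java.io.IOException;
  import java.net.ServerSocket;

  try (ServerSocket s = new ServerSocket(0)) {  // port 0 = let the OS choose
    int port = s.getLocalPort();                // the port actually bound
    System.out.println("listening on " + port); // hand this to the clients
  } catch (IOException e) {
    throw new RuntimeException(e);
  }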
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
index 3d73a37..a2c311e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
@@ -21,10 +21,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS;
 
+import java.io.IOException;
 import java.nio.ByteBuffer;
 
 import org.apache.commons.lang.exception.ExceptionUtils;
@@ -35,6 +33,7 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public final class PSRpcHandler extends RpcHandler {
@@ -47,28 +46,41 @@ public final class PSRpcHandler extends RpcHandler {
 
 	@Override
 	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
-		PSRpcCall call = new PSRpcCall(buffer);
+		PSRpcCall call;
+		try {
+			call = new PSRpcCall(buffer);
+		} catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
+		}
 		PSRpcResponse response = null;
 		switch (call.getMethod()) {
 			case PUSH:
 				try {
 					_server.push(call.getWorkerID(), call.getData());
-					response = new PSRpcResponse(SUCCESS, EMPTY_DATA);
+					response = new PSRpcResponse(Type.SUCCESS_EMPTY);
 				} catch (DMLRuntimeException exception) {
-					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
 				} finally {
-					callback.onSuccess(response.serialize());
+					try {
+						callback.onSuccess(response.serialize());
+					} catch (IOException e) {
+						throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
+					}
 				}
 				break;
 			case PULL:
 				ListObject data;
 				try {
 					data = _server.pull(call.getWorkerID());
-					response = new PSRpcResponse(SUCCESS, data);
+					response = new PSRpcResponse(Type.SUCCESS, data);
 				} catch (DMLRuntimeException exception) {
-					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
 				} finally {
-					callback.onSuccess(response.serialize());
+					try {
+						callback.onSuccess(response.serialize());
+					} catch (IOException e) {
+						throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
+					}
 				}
 				break;
 			default:

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
index c6d7fd3..7d3353f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -19,39 +19,86 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
 
-import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
 public abstract class PSRpcObject {
 
-	public static final String PUSH = "push";
-	public static final String PULL = "pull";
-	public static final String DATA_KEY = "data";
-	public static final String EMPTY_DATA = "";
+	public static final int PUSH = 1;
+	public static final int PULL = 2;
 
-	public abstract void deserialize(ByteBuffer buffer);
+	public abstract void deserialize(ByteBuffer buffer) throws IOException;
 
-	public abstract ByteBuffer serialize();
+	public abstract ByteBuffer serialize() throws IOException;
 
 	/**
-	 * Convert direct byte buffer to string
-	 * @param buffer direct byte buffer
-	 * @return string
+	 * Deep serialization and write of a list object (currently only lists containing matrices are supported)
+	 * @param lo a list object containing only matrices
+	 * @param dos output stream to write to
 	 */
-	protected String bufferToString(ByteBuffer buffer) {
-		byte[] result = new byte[buffer.limit()];
-		buffer.get(result, 0, buffer.limit());
-		return new String(result);
+	protected void serializeAndWriteListObject(ListObject lo, DataOutput dos) throws IOException {
+		validateListObject(lo);
+		dos.writeInt(lo.getLength()); //write list length
+		dos.writeBoolean(lo.isNamedList()); //write list named
+		for (int i = 0; i < lo.getLength(); i++) {
+			if (lo.isNamedList())
+				dos.writeUTF(lo.getName(i)); //write name
+			((MatrixObject) lo.getData().get(i))
+				.acquireReadAndRelease().write(dos); //write matrix
+		}
+	}
+	
+	protected ListObject readAndDeserialize(DataInput dis) throws IOException {
+		int listLen = dis.readInt();
+		List<Data> data = new ArrayList<>();
+		List<String> names = dis.readBoolean() ?
+			new ArrayList<>() : null;
+		for(int i=0; i<listLen; i++) {
+			if( names != null )
+				names.add(dis.readUTF());
+			MatrixBlock mb = new MatrixBlock();
+			mb.readFields(dis);
+			data.add(ParamservUtils.newMatrixObject(mb, false));
+		}
+		return new ListObject(data, names);
 	}
 
 	/**
-	 * Flush the data into HDFS
-	 * @param data list object
+	 * Get the approximate serialization size of a list object
+	 * (scheme: length | named | per entry: name | matrix)
+	 * @param lo list object
+	 * @return serialization size
 	 */
-	protected void flushListObject(ListObject data) {
-		data.getData().stream().filter(d -> d instanceof CacheableData)
-			.forEach(d -> ((CacheableData<?>) d).exportData());
+	protected int getApproxSerializedSize(ListObject lo) {
+		if( lo == null ) return 0;
+		long result = 4 + 1; // 4 bytes list length, 1 byte named flag
+		result += lo.getLength() * (Integer.BYTES); // per-entry headroom for the encoded name sizes
+		if (lo.isNamedList())
+			result += lo.getNames().stream().mapToLong(s -> s.length()).sum();
+		result += lo.getData().stream().mapToLong(d ->
+			((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum();
+		if( result > Integer.MAX_VALUE )
+			throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE.");
+		return (int) result;
+	}
+
+	private void validateListObject(ListObject lo) {
+		for (Data d : lo.getData()) {
+			if (!(d instanceof MatrixObject)) {
+				throw new DMLRuntimeException(String.format("Paramserv func:"
+					+ " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName()));
+			}
+		}
 	}
 }

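The list framing written by serializeAndWriteListObject is: int length | boolean named | per entry (UTF name if named, then the matrix block). A round trip of just that framing, with the matrix writes elided:

  import java.io.*;

  String[] names = {"W1", "b1"};
  ByteArrayOutputStream bos = new ByteArrayOutputStream();
  DataOutputStream dos = new DataOutputStream(bos);
  dos.writeInt(names.length); // list length
  dos.writeBoolean(true);     // named list?
  for (String n : names)
    dos.writeUTF(n);          // name (matrix write would follow here)
  dos.flush();

  DataInputStream dis = new DataInputStream(
    new ByteArrayInputStream(bos.toByteArray()));
  int len = dis.readInt();
  boolean named = dis.readBoolean();
  for (int i = 0; i < len; i++)
    System.out.println(dis.readUTF()); // prints W1, b1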
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
index 998c523..3517491 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
@@ -19,41 +19,43 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
-import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
-import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
-
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.StringTokenizer;
 
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.util.FastBufferedDataOutputStream;
 
 public class PSRpcResponse extends PSRpcObject {
+	public enum Type  {
+		SUCCESS,
+		SUCCESS_EMPTY,
+		ERROR,
+	}
+	
+	private Type _status;
+	private Object _data; // Could be list object or exception
 
-	public static final int SUCCESS = 1;
-	public static final int ERROR = 2;
-
-	private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN;
-	private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END;
-
-	private int _status;
-	private Object _data;	// Could be list object or exception
-
-	public PSRpcResponse(ByteBuffer buffer) {
+	public PSRpcResponse(ByteBuffer buffer) throws IOException {
 		deserialize(buffer);
 	}
 
-	public PSRpcResponse(int status, Object data) {
+	public PSRpcResponse(Type status) {
+		this(status, null);
+	}
+	
+	public PSRpcResponse(Type status, Object data) {
 		_status = status;
 		_data = data;
+		if( _status == Type.SUCCESS && data == null )
+			_status = Type.SUCCESS_EMPTY;
 	}
 
 	public boolean isSuccessful() {
-		return _status == SUCCESS;
+		return _status != Type.ERROR;
 	}
 
 	public String getErrorMessage() {
@@ -65,48 +67,42 @@ public class PSRpcResponse extends PSRpcObject {
 	}
 
 	@Override
-	public void deserialize(ByteBuffer buffer) {
-		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
-		String input = bufferToString(buffer);
-		//header elimination
-		input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end
-		StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
-
-		_status = Integer.valueOf(st.nextToken());
-		String data = st.nextToken();
+	public void deserialize(ByteBuffer buffer) throws IOException {
+		DataInputStream dis = new DataInputStream(
+			new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer)));
+		_status = Type.values()[dis.readInt()];
 		switch (_status) {
 			case SUCCESS:
-				_data = data.equals(EMPTY) ? null :
-					ProgramConverter.parseDataObject(data)[1];
+				_data = readAndDeserialize(dis);
+				break;
+			case SUCCESS_EMPTY:
 				break;
 			case ERROR:
-				_data = data;
+				_data = dis.readUTF();
 				break;
 		}
+		dis.close();
 	}
 
 	@Override
-	public ByteBuffer serialize() {
-		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
-		
-		StringBuilder sb = new StringBuilder();
-		sb.append(PS_RPC_RESPONSE_BEGIN);
-		sb.append(_status);
-		sb.append(COMPONENTS_DELIM);
+	public ByteBuffer serialize() throws IOException {
+		//TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling)
+		int len = 4 + (_status==Type.SUCCESS ? getApproxSerializedSize((ListObject)_data) :
+			_status==Type.SUCCESS_EMPTY ? 0 : ((String)_data).length());
+		ByteArrayOutputStream bos = new ByteArrayOutputStream(len);
+		FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos);
+		dos.writeInt(_status.ordinal());
 		switch (_status) {
 			case SUCCESS:
-				if (_data.equals(EMPTY_DATA)) {
-					sb.append(EMPTY);
-				} else {
-					flushListObject((ListObject) _data);
-					sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data));
-				}
+				serializeAndWriteListObject((ListObject) _data, dos);
+				break;
+			case SUCCESS_EMPTY:
 				break;
 			case ERROR:
-				sb.append(_data.toString());
+				dos.writeUTF(_data.toString());
 				break;
 		}
-		sb.append(PS_RPC_RESPONSE_END);
-		return ByteBuffer.wrap(sb.toString().getBytes());
+		dos.flush();
+		return ByteBuffer.wrap(bos.toByteArray());
 	}
 }

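The response status travels as the enum ordinal (writeInt(_status.ordinal()) on one side, Type.values()[readInt()] on the other). A two-line reminder of that round trip; the encoding is compact but tied to declaration order, so new Type values should only ever be appended:

  PSRpcResponse.Type sent = PSRpcResponse.Type.ERROR;
  int wire = sent.ordinal();                                   // 2 on the wire
  PSRpcResponse.Type recv = PSRpcResponse.Type.values()[wire]; // ERROR again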
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 6133987..fe238bd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -56,6 +56,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.recompile.Recompiler;
 import org.apache.sysml.lops.LopProperties;
@@ -125,11 +126,25 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
 		// Get the compiled execution context
 		LocalVariableMap newVarsMap = createVarsMap(sec);
-		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend
+		// Level of parallelism is 1 in the spark backend because one worker is launched per task
+		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);
 
 		MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
 		MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
 
+		// Create the agg service's execution context
+		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+
+		// Create the parameter server
+		ListObject model = sec.getListObject(getParam(PS_MODEL));
+		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+
+		// Get driver host
+		String host = sec.getSparkContext().getConf().get("spark.driver.host");
+
+		// Create the netty server for ps
+		TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps, host); // Start the server
+
 		// Force all the instructions to CP type
 		Recompiler.recompileProgramBlockHierarchy2Forced(
 			newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);
@@ -139,29 +154,24 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		HashMap<String, byte[]> clsMap = new HashMap<>();
 		String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
 
-		// Get some configurations
-		String host = sec.getSparkContext().getConf().get("spark.driver.host");
-		long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ? 
-			sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") :
-			sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s");
+		// Add the accumulators for statistics
+		LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
+		LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
+		LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
+		LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
+		LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
+		LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
+		LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
+		LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
 
 		// Create remote workers
-		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(),
-			getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout);
-
-		// Create the agg service's execution context
-		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
-
-		// Create the parameter server
-		ListObject model = sec.getListObject(getParam(PS_MODEL));
-		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), 
+			getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
+			server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);
 
 		if (DMLScript.STATISTICS)
 			Statistics.accPSSetupTime((long) tSetup.stop());
 
-		// Create the netty server for ps
-		TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server
-
 		try {
 			ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
 				.foreach(worker); // Run remote workers
@@ -172,6 +182,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 			server.close();
 		}
 
+		// Accumulate the statistics for remote workers
+		if (DMLScript.STATISTICS) {
+			Statistics.accPSSetupTime(aSetup.sum());
+			Statistics.incWorkerNumber(aWorker.sum());
+			Statistics.accPSLocalModelUpdateTime(aUpdate.sum());
+			Statistics.accPSBatchIndexingTime(aIndex.sum());
+			Statistics.accPSGradientComputeTime(aGrad.sum());
+			Statistics.accPSRpcRequestTime(aRPC.sum());
+		}
+
 		// Fetch the final model from ps
 		sec.setVariable(output.getName(), ps.getResult());
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
index 94941a1..18f0e54 100644
--- a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
+++ b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
@@ -26,6 +26,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.StringReader;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -608,4 +609,13 @@ public class IOUtilFunctions
 		ba[ off+6 ] = (byte)((val >>>  8) & 0xFF);
 		ba[ off+7 ] = (byte)((val >>>  0) & 0xFF);
 	}
+	
+	public static byte[] getBytes(ByteBuffer buff) {
+		int len = buff.limit();
+		if( buff.hasArray() )
+			return Arrays.copyOf(buff.array(), len);
+		byte[] ret = new byte[len];
+		buff.get(ret, buff.position(), len);
+		return ret;
+	}
 }

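getBytes covers both buffer flavors: heap buffers (hasArray() == true) take the cheap Arrays.copyOf path, while direct buffers, such as the payloads typically handed over by the netty-based network layer, fall back to a bulk get. A quick sketch exercising both paths:

  import java.nio.ByteBuffer;
  import org.apache.sysml.runtime.io.IOUtilFunctions;

  byte[] src = {1, 2, 3, 4};
  ByteBuffer heap = ByteBuffer.wrap(src);           // hasArray() == true
  ByteBuffer direct = ByteBuffer.allocateDirect(4); // hasArray() == false
  direct.put(src);
  direct.flip();
  byte[] a = IOUtilFunctions.getBytes(heap);   // Arrays.copyOf path
  byte[] b = IOUtilFunctions.getBytes(direct); // bulk-get path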
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 1dd8362..44667ba 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -541,6 +541,10 @@ public class Statistics
 		psNumWorkers.increment();
 	}
 
+	public static void incWorkerNumber(long n) {
+		psNumWorkers.add(n);
+	}
+
 	public static void accPSSetupTime(long t) {
 		psSetupTime.add(t);
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
index 30eccb3..89235d7 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
@@ -1,15 +1,8 @@
 package org.apache.sysml.test.integration.functions.paramserv;
 
-import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
@@ -42,12 +35,12 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 
 	@Test
 	public void testParamservBSPEpochDisjointContiguous() {
-		runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(5, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
 	public void testParamservASPEpochDisjointContiguous() {
-		runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+		runDMLTest(5, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
 	}
 
 	@Test
@@ -62,9 +55,14 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 
 	private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) {
 		programArgs = new String[] { "-explain" };
+		internalRunDMLTest(testname, exceptionExpected, expectedException, errMessage);
+	}
+
+	private void internalRunDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException,
+			String errMessage) {
 		DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform;
 		boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK;
+		AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK;
 		DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
 		try {
@@ -80,22 +78,7 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 	}
 
 	private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
-		Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString())
-			.in("$epochs", String.valueOf(epochs))
-			.in("$workers", String.valueOf(workers))
-			.in("$utype", utype.toString())
-			.in("$freq", freq.toString())
-			.in("$batchsize", String.valueOf(batchsize))
-			.in("$scheme", scheme.toString());
-
-		SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]")
-			.set("spark.driver.allowMultipleContexts", "true");
-		JavaSparkContext sc = new JavaSparkContext(conf);
-		MLContext ml = new MLContext(sc);
-		ml.setStatistics(true);
-		ml.execute(script);
-		ml.resetConfig();
-		sc.stop();
-		ml.close();
+		programArgs = new String[] { "-explain", "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme};
+		internalRunDMLTest(TEST_NAME1, false, null, null);
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
index 57e1106..f2df1e6 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import java.io.IOException;
 import java.util.Arrays;
 
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
-import org.apache.sysml.runtime.instructions.cp.IntObject;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.junit.Assert;
 import org.junit.Test;
@@ -33,24 +33,26 @@ import org.junit.Test;
 public class RpcObjectTest {
 
 	@Test
-	public void testPSRpcCall() {
+	public void testPSRpcCall() throws IOException {
 		MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
 		MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-		IntObject io = new IntObject(30);
-		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+		ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
 		PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
 		PSRpcCall actual = new PSRpcCall(expected.serialize());
-		Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+		Assert.assertTrue(Arrays.equals(
+			expected.serialize().array(),
+			actual.serialize().array()));
 	}
 
 	@Test
-	public void testPSRpcResponse() {
+	public void testPSRpcResponse() throws IOException {
 		MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
 		MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-		IntObject io = new IntObject(30);
-		ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
-		PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo);
+		ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
+		PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.Type.SUCCESS, lo);
 		PSRpcResponse actual = new PSRpcResponse(expected.serialize());
-		Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+		Assert.assertTrue(Arrays.equals(
+			expected.serialize().array(),
+			actual.serialize().array()));
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
index 4d0f32e..d1edc29 100644
--- a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-e1 = "element1"
+e1 = matrix(1, rows=100, cols=10)
 modelList = list(e1)
 X = matrix(1, rows=200, cols=30)
 Y = matrix(2, rows=200, cols=1)
@@ -42,11 +42,9 @@ aggregation = function(list[unknown] model,
   print(toString(as.matrix(gradients["agg_service_err"])))
 }
 
-e2 = "element2"
+e2 = matrix(2, rows=100, cols=10)
 params = list(e2)
 
-modelList = list("model")
-
 # Use paramserv function
 modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
index ad16122..bf0de68 100644
--- a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-e1 = "element1"
+e1 = matrix(1, rows=100, cols=10)
 modelList = list(e1)
 X = matrix(1, rows=200, cols=30)
 Y = matrix(2, rows=200, cols=1)
@@ -42,11 +42,9 @@ aggregation = function(list[unknown] model,
   modelResult = model
 }
 
-e2 = "element2"
+e2 = matrix(2, rows=100, cols=10)
 params = list(e2)
 
-modelList = list("model")
-
 # Use paramserv function
 modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)