Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/29 17:36:03 UTC

[systemml] branch master updated: [SYSTEMML-2525] Initial implementation of RESTful model serving system

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 863c9d5  [SYSTEMML-2525] Initial implementation of RESTful model serving system
863c9d5 is described below

commit 863c9d5cb1752b0e50140f5c6673968b57c2f9d0
Author: Anthony Thomas <ah...@eng.ucsd.edu>
AuthorDate: Fri Mar 29 10:27:54 2019 -0700

    [SYSTEMML-2525] Initial implementation of RESTful model serving system
    
    - The current implementation extends JMLC's readMatrix and GPUContext APIs.
    - The serving system is implemented in Scala using Akka and is available in the org.apache.sysml.api.ml.serving package.
    - Minor cleanup and refactoring, required before it is ready for use by the general public, will be done in subsequent commits.
    - It remains unclear whether the CUDA and serving code should be included in future standalone releases. If so, deployment is greatly simplified; otherwise, the user will have to build the standalone jar before deployment.
    - The serving system can be started by:
    ```
    mvn -Djcuda.scope=compile -Dserving.scope=compile package -P standalone-jar
    java -jar systemml-*-standalone.jar org.apache.sysml.api.ml.serving.PredictionService -port 8099 -scheduler scheduler -admin_password admin
    ```
    - A model can be registered via the http://localhost:8099/register-model endpoint, and predictions can be invoked via the http://localhost:8099/predict endpoint (see the sketch below).
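    
    For illustration, a hedged sketch of a Scala client call (the JSON payload shape is
    an assumption; the actual request schema is defined in PredictionService.scala):
    ```
    // Hypothetical client for the prediction endpoint (illustration only)
    import java.net.{HttpURLConnection, URL}
    
    val conn = new URL("http://localhost:8099/predict")
      .openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    conn.setDoOutput(true)
    conn.setRequestProperty("Content-Type", "application/json")
    val payload = """{"name": "my-model", "data": "1.0,2.0,3.0"}""" // assumed shape
    conn.getOutputStream.write(payload.getBytes("UTF-8"))
    println(s"HTTP ${conn.getResponseCode}")
    ```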
    
    Closes #860.
---
 .travis.yml                                        |   5 +-
 pom.xml                                            |  31 ++
 .../java/org/apache/sysml/api/jmlc/Connection.java |  51 +++
 .../org/apache/sysml/api/jmlc/PreparedScript.java  |  18 +
 .../org/apache/sysml/parser/DataExpression.java    |   1 +
 .../runtime/controlprogram/LocalVariableMap.java   |   4 +
 .../org/apache/sysml/utils/PersistentLRUCache.java |  97 ++--
 .../api/ml/serving/BasicBatchingScheduler.scala    |  93 ++++
 .../sysml/api/ml/serving/BatchingScheduler.scala   |  99 +++++
 .../sysml/api/ml/serving/BatchingUtils.scala       |  57 +++
 .../org/apache/sysml/api/ml/serving/Executor.scala | 155 +++++++
 .../api/ml/serving/LocalityAwareScheduler.scala    | 218 +++++++++
 .../apache/sysml/api/ml/serving/ModelManager.scala | 176 ++++++++
 .../api/ml/serving/NonBatchingScheduler.scala      |  69 +++
 .../sysml/api/ml/serving/PredictionService.scala   | 490 +++++++++++++++++++++
 .../apache/sysml/api/ml/serving/RLSEstimator.scala |  91 ++++
 .../apache/sysml/api/ml/serving/Scheduler.scala    | 133 ++++++
 .../sysml/api/ml/serving/SchedulerFactory.scala    |  29 ++
 18 files changed, 1756 insertions(+), 61 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index a0c308b..3ce9d06 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -46,7 +46,8 @@ before_script:
 
 script:
 #  - mvn clean verify jacoco:report coveralls:report
-  - mvn clean verify
+# The -q parameter tells mvn not to display anything other than ERROR level log messages. This is required because travis kills the job after the log length exceeds its maximum log length (usually 4 MB).
+  - mvn -q clean verify
 
 after_success:
-#  -  mvn test jacoco:report coveralls:report
\ No newline at end of file
+#  -  mvn test jacoco:report coveralls:report
diff --git a/pom.xml b/pom.xml
index ad74276..4b5dd29 100644
--- a/pom.xml
+++ b/pom.xml
@@ -72,6 +72,7 @@
 		<maven.build.timestamp.format>yyyy-MM-dd HH:mm:ss z</maven.build.timestamp.format>
 		<enableGPU>false</enableGPU>
 		<jcuda.scope>provided</jcuda.scope>
+		<serving.scope>provided</serving.scope>
 		<jcuda.version>0.9.0d</jcuda.version>
 		<!-- OS-specific JVM arguments for running integration tests -->
 		<integrationTestExtraJVMArgs />
@@ -1259,6 +1260,36 @@
 			<version>3.2.0</version>
 		</dependency>
 		<dependency>
+		    <groupId>com.typesafe.akka</groupId>
+		    <artifactId>akka-http_2.11</artifactId>
+		    <version>10.1.3</version>
+		    <scope>${serving.scope}</scope>
+		</dependency>
+		<dependency>
+		    <groupId>com.typesafe.akka</groupId>
+		    <artifactId>akka-actor_2.11</artifactId>
+		    <version>2.5.14</version>
+			<scope>${serving.scope}</scope>
+		</dependency>
+		<dependency>
+		    <groupId>com.typesafe.akka</groupId>
+		    <artifactId>akka-stream_2.11</artifactId>
+		    <version>2.5.14</version>
+			<scope>${serving.scope}</scope>
+		</dependency>
+		<dependency>
+		    <groupId>com.typesafe</groupId>
+		    <artifactId>config</artifactId>
+		    <version>1.2.0</version>
+			<scope>${serving.scope}</scope>
+		</dependency>
+		<dependency>
+		    <groupId>com.typesafe.akka</groupId>
+		    <artifactId>akka-http-spray-json-experimental_2.11</artifactId>
+		    <version>2.4.11.2</version>
+			<scope>${serving.scope}</scope>
+		</dependency>
+		<dependency>
 			<groupId>org.jcuda</groupId>
 			<artifactId>jcuda</artifactId>
 			<version>${jcuda.version}</version>
diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
index 53b7d04..29df4c0 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
@@ -370,6 +370,57 @@ public class Connection implements Closeable
 	// Read matrices
 	////////////////////////////////////////////
 	
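+	/**
+	 * Reads an input matrix in arbitrary format from HDFS into a MatrixBlock,
+	 * deriving the dimensions, block sizes, non-zeros, and format from the
+	 * matrix's metadata (.mtd) file.
+	 *
+	 * @param fname the filename of the input matrix (with accompanying metadata file)
+	 * @return matrix as a MatrixBlock
+	 * @throws IOException if IOException occurs
+	 */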
+	public MatrixBlock readMatrix(String fname) throws IOException {
+		try {
+			String fnamemtd = DataExpression.getMTDFileName(fname);
+			JSONObject jmtd = new DataExpression().readMetadataFile(fnamemtd, false);
+
+			//parse json meta data
+			long rows = jmtd.getLong(DataExpression.READROWPARAM);
+			long cols = jmtd.getLong(DataExpression.READCOLPARAM);
+			int brlen = jmtd.containsKey(DataExpression.ROWBLOCKCOUNTPARAM)?
+					jmtd.getInt(DataExpression.ROWBLOCKCOUNTPARAM) : -1;
+			int bclen = jmtd.containsKey(DataExpression.COLUMNBLOCKCOUNTPARAM)?
+					jmtd.getInt(DataExpression.COLUMNBLOCKCOUNTPARAM) : -1;
+			long nnz = jmtd.containsKey(DataExpression.READNNZPARAM)?
+					jmtd.getLong(DataExpression.READNNZPARAM) : -1;
+			String format = jmtd.getString(DataExpression.FORMAT_TYPE);
+			InputInfo iinfo = InputInfo.stringExternalToInputInfo(format);
+			return readMatrix(fname, iinfo, rows, cols, brlen, bclen, nnz);
+		} catch (Exception ex) {
+			throw new IOException(ex);
+		}
+	}
+	
+	/**
+	 * Reads an input matrix in arbitrary format from HDFS into a MatrixBlock.
+	 * NOTE: this call currently only supports default configurations for CSV.
+	 *
+	 * @param fname the filename of the input matrix
+	 * @param iinfo InputInfo object
+	 * @param rows number of rows in the matrix
+	 * @param cols number of columns in the matrix
+	 * @param brlen number of rows per block
+	 * @param bclen number of columns per block
+	 * @param nnz number of non-zero values, -1 indicates unknown
+	 * @return matrix as a MatrixBlock
+	 * @throws IOException if IOException occurs
+	 */
+	public MatrixBlock readMatrix(String fname, InputInfo iinfo, long rows, long cols, int brlen, int bclen, long nnz)
+			throws IOException
+	{
+		setLocalConfigs();
+
+		try {
+			MatrixReader reader = MatrixReaderFactory.createMatrixReader(iinfo);
+			return reader.readMatrixFromHDFS(fname, rows, cols, brlen, bclen, nnz);
+
+		}
+		catch(Exception ex) {
+			throw new IOException(ex);
+		}
+	}
+	
 	/**
 	 * Reads an input matrix in arbitrary format from HDFS into a dense double array.
 	 * NOTE: this call currently only supports default configurations for CSV.
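
For context, a minimal sketch of using the new Connection.readMatrix API from JMLC
(illustrative only; the file path and its metadata file are assumptions):
```
// Hypothetical usage of the new readMatrix API (illustration only).
// Assumes /tmp/W.mtx has an accompanying /tmp/W.mtx.mtd metadata file.
import org.apache.sysml.api.jmlc.Connection
import org.apache.sysml.runtime.matrix.data.MatrixBlock

val conn = new Connection()
try {
  val mb: MatrixBlock = conn.readMatrix("/tmp/W.mtx")
  println(s"Read ${mb.getNumRows} x ${mb.getNumColumns} matrix")
} finally {
  conn.close()
}
```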
diff --git a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
index 701af30..5926bcc 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
@@ -75,6 +75,11 @@ public class PreparedScript implements ConfigurableAPI
 	private final HashSet<String> _outVarnames;
 	private final HashMap<String,Data> _inVarReuse;
 	
+	private String name = "";
+	public void setName(String name) {
+		this.name = name;
+	}
+	
 	//internal state (reused)
 	private final Program _prog;
 	private final LocalVariableMap _vars;
@@ -131,6 +136,19 @@ public class PreparedScript implements ConfigurableAPI
 		_cconf = cconf;
 	}
 	
+	public void clearPinnedData() {
+		this._inVarReuse.clear();
+	}
+	
+	public boolean hasPinnedData() {
+		return !_inVarReuse.isEmpty();
+	}
+	
+	public void setGpuContext(GPUContext gCtx) { 
+		this._gpuCtx.set(0, gCtx); 
+	}
+	
 	/**
 	 * Sets a boolean flag indicating if runtime statistics should be gathered
 	 * Same behavior as in "MLContext.setStatistics()"
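
A hedged sketch of how the new PreparedScript helpers might be used by the serving
layer (method names are from this diff; the eviction scenario is an assumption):
```
// Illustration only: drop pinned weights so they can be re-pinned on next acquire.
import org.apache.sysml.api.jmlc.PreparedScript

def evictWeights(ps: PreparedScript): Unit = {
  if (ps.hasPinnedData)  // weights currently pinned for reuse
    ps.clearPinnedData() // release them; a later acquire re-reads from the cache
}
```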
diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java
index 44f368e..7b64922 100644
--- a/src/main/java/org/apache/sysml/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysml/parser/DataExpression.java
@@ -44,6 +44,7 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.utils.JSONHelper;
 import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONObject;
+import org.apache.sysml.parser.Expression.DataOp;
 
 
 public class DataExpression extends DataIdentifier 
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index cf9e79a..bb06849 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -131,6 +131,10 @@ public class LocalVariableMap implements Cloneable
 			put(kv.getKey(), kv.getValue());
 		}
 	}
+	
+	public void putAll(LocalVariableMap vars) { 
+		putAll(vars.localMap); 
+	}
 
 	public Data remove( String name ) {
 		Data ret = localMap.remove( name );
diff --git a/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java b/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
index d9d9337..24e685b 100644
--- a/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
+++ b/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
@@ -86,7 +86,7 @@ public class PersistentLRUCache extends LinkedHashMap<String, ValueWrapper> {
 	private String _prefixFilePath;
 	final AtomicLong _currentNumBytes = new AtomicLong();
 	private final long _maxNumBytes;
-	private static final Random _rand = new Random();
+	Random _rand = new Random();
 	boolean isInReadOnlyMode;
 	HashSet<String> persistedKeys = new HashSet<>();
 	
@@ -101,9 +101,6 @@ public class PersistentLRUCache extends LinkedHashMap<String, ValueWrapper> {
 		for(long i = 0; i < numIter; ++i) {
 			LOG.debug("Putting a double array of size 50MB.");
 			cache.put("file_" + i, new double[numDoubleIn50MB]);
-			try {
-				Thread.sleep(100);
-			} catch (InterruptedException e) {}
 		}
 		cache.clear();
 	}
@@ -130,13 +127,13 @@ public class PersistentLRUCache extends LinkedHashMap<String, ValueWrapper> {
 		_prefixFilePath = tmp.getAbsolutePath();
 	}
 	public ValueWrapper put(String key, double[] value) throws FileNotFoundException, IOException {
-		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this), isInReadOnlyMode), value.length*Double.BYTES);
+		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this)), value.length*Double.BYTES);
 	}
 	public ValueWrapper put(String key, float[] value) throws FileNotFoundException, IOException {
-		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this), isInReadOnlyMode), value.length*Float.BYTES);
+		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this)), value.length*Float.BYTES);
 	}
 	public ValueWrapper put(String key, MatrixBlock value) throws FileNotFoundException, IOException {
-		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this), isInReadOnlyMode), value.getInMemorySize());
+		return putImplm(key, new ValueWrapper(new DataWrapper(key, value, this)), value.getInMemorySize());
 	}
 	
 	private ValueWrapper putImplm(String key, ValueWrapper value, long sizeInBytes) throws FileNotFoundException, IOException {
@@ -209,7 +206,7 @@ public class PersistentLRUCache extends LinkedHashMap<String, ValueWrapper> {
     }
 	
 	float [] tmp = new float[0];
-	static String dummyKey = "RAND_KEY_" + Math.abs(_rand.nextLong()) + "_" + Math.abs(_rand.nextLong());
+	String dummyKey = "RAND_KEY_" + Math.abs(_rand.nextLong()) + "_" + Math.abs(_rand.nextLong());
 	void ensureCapacity(long newNumBytes) throws FileNotFoundException, IOException {
 		if(newNumBytes > _maxNumBytes) {
 			throw new DMLRuntimeException("Exceeds maximum capacity. Cannot put a value of size " + newNumBytes + 
@@ -220,7 +217,7 @@ public class PersistentLRUCache extends LinkedHashMap<String, ValueWrapper> {
 			synchronized(this) {
 				if(LOG.isDebugEnabled())
 					LOG.debug("The required capacity (" + newCapacity + ") is greater than max capacity:" + _maxNumBytes);
-				ValueWrapper dummyValue = new ValueWrapper(new DataWrapper(dummyKey, tmp, this), isInReadOnlyMode);
+				ValueWrapper dummyValue = new ValueWrapper(new DataWrapper(dummyKey, tmp, this));
 				int maxIter = size();
 				while(_currentNumBytes.get() > _maxNumBytes && maxIter > 0) {
 					super.put(dummyKey, dummyValue); // This will invoke removeEldestEntry, which will set _eldest
@@ -351,13 +348,17 @@ class DataWrapper {
 		_mo = value;
 		_cache = cache;
 	}
+	@Override
+	protected void finalize() throws Throwable {
+		super.finalize();
+		write(true);
+	}
 	
-	public synchronized void write(boolean forceAggresiveWrites) throws FileNotFoundException, IOException {
-		if(_key.equals(PersistentLRUCache.dummyKey))
+	public synchronized void write(boolean isBeingGarbageCollected) throws FileNotFoundException, IOException {
+		if(_key.equals(_cache.dummyKey))
 			return;
-		
-		// Prepare for writing
 		_cache.makeRecent(_key); // Make it recent.
+		
 		if(_dArr != null || _fArr != null || _mb != null || _mo != null) {
 			_cache._currentNumBytes.addAndGet(-getSize());
 		}
@@ -365,16 +366,14 @@ class DataWrapper {
 		if(!_cache.isInReadOnlyMode) {
 			String debugSuffix = null;
 			if(PersistentLRUCache.LOG.isDebugEnabled()) {
-				if(forceAggresiveWrites)
-					debugSuffix = " (aggressively written).";
+				if(isBeingGarbageCollected)
+					debugSuffix = " (is being garbage collected).";
 				else
 					debugSuffix = " (capacity exceeded).";
 			}
 			
 			if(_dArr != null) {
-				File file = new File(_cache.getFilePath(_key));
-				file.deleteOnExit();
-				try (ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(file))) {
+				try (ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(_cache.getFilePath(_key)))) {
 					os.writeInt(_dArr.length);
 					for(int i = 0; i < _dArr.length; i++) {
 						os.writeDouble(_dArr[i]);
@@ -385,9 +384,7 @@ class DataWrapper {
 					PersistentLRUCache.LOG.debug("Writing value (double[] of size " + getSize() + " bytes) for the key " + _key + " to disk" + debugSuffix);
 			}
 			else if(_fArr != null) {
-				File file = new File(_cache.getFilePath(_key));
-				file.deleteOnExit();
-				try (ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(file))) {
+				try (ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(_cache.getFilePath(_key)))) {
 					os.writeInt(_fArr.length);
 					for(int i = 0; i < _fArr.length; i++) {
 						os.writeFloat(_fArr[i]);
@@ -398,13 +395,12 @@ class DataWrapper {
 					PersistentLRUCache.LOG.debug("Writing value (float[] of size " + getSize() + " bytes) for the key " + _key + " to disk" + debugSuffix);
 			}
 			else if(_mb != null) {
-				File file = new File(_cache.getFilePath(_key));
-				file.deleteOnExit();
-				try(FastBufferedDataOutputStream os = new FastBufferedDataOutputStream(new ObjectOutputStream(new FileOutputStream(file)))) {
+				try(FastBufferedDataOutputStream os = new FastBufferedDataOutputStream(new ObjectOutputStream(new FileOutputStream(_cache.getFilePath(_key))))) {
 					os.writeLong(_mb.getInMemorySize());
 					_mb.write(os);
 				}
 				_cache.persistedKeys.add(_key);
 				if(PersistentLRUCache.LOG.isDebugEnabled())
 					PersistentLRUCache.LOG.debug("Writing value (MatrixBlock of size " + getSize() + " bytes) for the key " + _key + " to disk" + debugSuffix);
 			}
@@ -513,63 +509,46 @@ class DataWrapper {
 // Internal helper class
 class ValueWrapper {
 	final Object _lock;
-	final boolean _isInReadOnlyMode;
-	private SoftReference<DataWrapper> _softRef;
+	private SoftReference<DataWrapper> _ref;
 	long _rlen;
 	long _clen;
 	long _nnz;
 	
-	ValueWrapper(DataWrapper data, boolean isInReadOnlyMode) {
+	ValueWrapper(DataWrapper data) {
 		_lock = new Object();
-		_isInReadOnlyMode = isInReadOnlyMode;
-		boolean isDummyValue = (data._key == PersistentLRUCache.dummyKey);
-		if(!_isInReadOnlyMode && !isDummyValue) {
-			// Aggressive write to disk when the cache is used in the write-mode.
-			// This avoids the need to depend on finalize to perform writing.
-			Thread t = new Thread() {
-			    public void run() {
-			    	try {
-			    		data.write(true);
-					} catch (IOException e) {
-						throw new DMLRuntimeException("Error occured while aggressively writing the value to disk.", e);
-					}
-			    }
-			};
-			t.start();
-		}
-		_softRef = new SoftReference<>(data);
-		if(data._mb != null) {
-			_rlen = data._mb.getNumRows();
-			_clen = data._mb.getNumColumns();
-			_nnz = data._mb.getNonZeros();
+		_ref = new SoftReference<>(data);
+		if(data._mb != null) {
+			_rlen = data._mb.getNumRows();
+			_clen = data._mb.getNumColumns();
+			_nnz = data._mb.getNonZeros();
 		}
 	}
-	void update(DataWrapper data) {
-		_softRef = new SoftReference<>(data);
-		if(data._mb != null) {
-			_rlen = data._mb.getNumRows();
-			_clen = data._mb.getNumColumns();
-			_nnz = data._mb.getNonZeros();
+	void update(DataWrapper data) {
+		_ref = new SoftReference<>(data);
+		if(data._mb != null) {
+			_rlen = data._mb.getNumRows();
+			_clen = data._mb.getNumColumns();
+			_nnz = data._mb.getNonZeros();
 		}
 	}
 	boolean isAvailable() {
-		DataWrapper data = _softRef.get();
+		DataWrapper data = _ref.get();
 		return data != null && data.isAvailable();
 	}
 	DataWrapper get() {
-		return _softRef.get();
+		return _ref.get();
 	}
 	long getSize() {
-		DataWrapper data = _softRef.get();
+		DataWrapper data = _ref.get();
 		if(data != null) 
 			return data.getSize();
 		else
 			return 0;
 	}
 	void remove() {
-		DataWrapper data = _softRef.get();
+		DataWrapper data = _ref.get();
 		if(data != null) {
 			data.remove();
 		}
 	}
-}
+}
\ No newline at end of file
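
For context, a minimal sketch of how the serving layer uses this cache as a read-only
weight store (the 0.8 budget fraction mirrors ReferenceCountedModelManager below; the
budget and file path are assumptions):
```
// Illustration only: a read-only weight cache backed by PersistentLRUCache.
import org.apache.sysml.utils.PersistentLRUCache

val budgetBytes = 1L << 30 // assumed 1 GB budget
val cache = new PersistentLRUCache((0.8 * budgetBytes).toLong)
cache.enableReadOnlyMode(true) // evicted entries are re-read from disk on access
val w = cache.getAsMatrixBlock("/models/lm/W.bin") // hypothetical weight file path
```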
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala b/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala
new file mode 100644
index 0000000..6f19cce
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+import scala.math.min
+
+object BasicBatchingScheduler extends BatchingScheduler {
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Basic Batching Scheduler with: ${numCores} CPUs and ${gpus} GPUs")
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+    }
+
+    /**
+      * Returns a list of requests to execute. If the list contains more than one element, they will be batched
+      * by the executor. Returns an empty list when there are no models to be scheduled.
+      * @param executor an Executor instance
+      * @return a list of model requests to process
+      */
+    override def schedule(executor: JmlcExecutor) : Array[SchedulingRequest] = {
+        var ret = Array[SchedulingRequest]()
+        val execType = executor.getExecType
+        dummyResponse.synchronized {
+            val schedulableModels = getSchedulableModels(execType)
+            if (schedulableModels.nonEmpty) {
+                val (nextModel, nextBatchSize) = getNextModelAndBatchSize(schedulableModels, execType)
+                for (_ <- 0 until nextBatchSize) {
+                    val next = modelQueues.get(nextModel).poll()
+                    assert(next != null, "Something is wrong. Next request should not be null")
+                    ret :+= next
+                }
+            }
+        }
+        ret
+    }
+
+    /**
+      * Helper method which gets the next model to schedule and its optimal batch size
+      * @param models A list of models to schedule
+      * @param execType The executor type (CPU or GPU) for which to compute the batch size
+      * @return A tuple of (model to schedule next, batch size to use for it)
+      */
+    def getNextModelAndBatchSize(models : Iterable[String], execType: String) : (String, Int) = {
+        val nextModel = models.map(m =>
+            (getOptimalBatchSize(m, execType)*getExpectedExecutionTime(m), m)).minBy(x => x._1)._2
+
+        val nextBatchSize = min(modelQueues.get(nextModel).size(),
+            getOptimalBatchSize(nextModel, execType))
+        (nextModel, nextBatchSize)
+    }
+
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine which
+      * models to execute next
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which a prediction is desired
+      * @return A Future that completes with the PredictionResponse
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+        if (_statistics) statistics.queueSize = modelQueues.get(model.name).size
+        modelQueues.get(model.name).add(schedulingRequest)
+        counter += 1
+        // CountDownLatch.await returns false on timeout rather than throwing,
+        // so fall back to the dummy response explicitly on timeout
+        if (schedulingRequest.latch.await(timeout.length, timeout.unit))
+            schedulingRequest.response
+        else
+            dummyResponse
+    }
+
+}
\ No newline at end of file
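
To make the selection rule in getNextModelAndBatchSize concrete, a small
self-contained sketch (model names and timings are made up):
```
// Pick the model minimizing optimalBatchSize * expectedExecutionTime.
val candidates = Map(
  "lm"  -> (4, 2000000L), // (optimal batch size, expected ns per example)
  "mlp" -> (2, 9000000L))
val next = candidates.minBy { case (_, (bs, t)) => bs * t }._1 // "lm": 8ms < 18ms
```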
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala
new file mode 100644
index 0000000..e62f4ef
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.LongAdder
+
+import scala.math.{floor, max}
+
+trait BatchingScheduler extends Scheduler {
+
+    val modelBatchSizes = new ConcurrentHashMap[String, ConcurrentHashMap[String,Int]]()
+    val expectedExecutionTimes = new ConcurrentHashMap[String, (LongAdder, LongAdder)]()
+
+    def getOptimalBatchSize(model : String, execType: String) : Int = {
+        modelBatchSizes.putIfAbsent(execType, new ConcurrentHashMap[String,Int]())
+        modelBatchSizes.get(execType).putIfAbsent(model, 2)
+        modelBatchSizes.get(execType).get(model)
+    }
+
+    override def onCompleteCallback(model: String,
+                                    latency: Double,
+                                    batchSize: Int,
+                                    execType: String,
+                                    execTime: Long): Unit = {
+        if (batchSize > 1) {
+            val latencyObjective = latencyObjectives.get(model)
+            val prevSize = modelBatchSizes.get(execType).get(model)
+            val decreaseSize = if (prevSize > 10) max(floor(prevSize * 0.90).toInt, 1) else prevSize - 1
+            modelBatchSizes.get(execType).put(model,
+                if (latency < latencyObjective.toNanos) prevSize + 1 else decreaseSize)
+
+            // update expected execution times. For now we just assume this is a simple average
+            val execTimeData = expectedExecutionTimes.get(model)
+            execTimeData._1.add(execTime / batchSize)
+            execTimeData._2.increment()
+        }
+    }
+
+    def getExpectedExecutionTime(model: String) : Long = {
+        expectedExecutionTimes.putIfAbsent(model, (new LongAdder(), new LongAdder()))
+        val execTime = expectedExecutionTimes.get(model)
+        val totalNumRequests = execTime._2.longValue()
+        if (totalNumRequests > 0) execTime._1.longValue() / totalNumRequests else 0
+    }
+
+    /**
+      * Gets a list of models that are eligible to be run. A model is eligible to be run if it
+      * has a greater number of requests enqueued than its optimal batch size.
+      * @return A list of models which may be scheduled
+      */
+    def getSchedulableModels(execType: String) : Set[String] = {
+        var batchableModels = Set[String]()
+        var shortFuse = Set[String]()
+        val keyIterator = modelQueues.keys()
+        while (keyIterator.hasMoreElements) {
+            val name = keyIterator.nextElement()
+            val qsize = modelQueues.get(name).size()
+            if (qsize > 0) {
+                val nextRequest = modelQueues.get(name).peek()
+                assert(nextRequest != null, "Something is wrong. Next request should not be null")
+
+                if (checkShortFuse(nextRequest, qsize)) {
+                    LOG.info("Model: " + name + " is near violating threshold. Scheduling immediately.")
+                    shortFuse += name
+                }
+
+                if (qsize >= getOptimalBatchSize(name, execType)) {
+                    batchableModels += name
+                }
+            }
+        }
+
+        if (shortFuse.nonEmpty) shortFuse else batchableModels
+    }
+
+    /**
+      * Returns true if executing the currently enqueued requests for this model would
+      * risk violating its latency objective
+      */
+    def checkShortFuse(request: SchedulingRequest, numRequests: Int) : Boolean = {
+        val elapsed = System.nanoTime() - request.receivedTime
+        (elapsed + 1.1*numRequests*getExpectedExecutionTime(request.model.name)) > request.model.latencyObjective.toNanos
+    }
+}
\ No newline at end of file
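
The batch-size adaptation in onCompleteCallback is an additive-increase /
multiplicative-decrease rule; the same update, distilled:
```
// Grow the batch by 1 while the latency objective is met; shrink by ~10%
// (or by 1 for batch sizes of 10 or less) when it is violated.
def nextBatchSize(prev: Int, latencyNs: Double, objectiveNs: Long): Int =
  if (latencyNs < objectiveNs) prev + 1
  else if (prev > 10) math.max(math.floor(prev * 0.90).toInt, 1)
  else prev - 1
```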
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala
new file mode 100644
index 0000000..ca28e7f
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+
+object BatchingUtils {
+        def batchRequests(requests: Array[SchedulingRequest]) : MatrixBlock = {
+            if (requests.length == 1) {
+                return requests(0).request.data
+            }
+            val ncol = requests(0).request.data.getNumColumns
+            val res = new MatrixBlock(requests.length, ncol, -1).allocateDenseBlock()
+            val doubles = res.getDenseBlockValues
+            var start = 0
+            for (req <- requests) {
+                System.arraycopy(req.request.data.getDenseBlockValues, 0, doubles, start, ncol)
+                start += ncol
+            }
+            res.setNonZeros(-1)
+            res
+        }
+
+        def unbatchRequests(requests: Array[SchedulingRequest],
+                            batchedResults: MatrixBlock) : Array[PredictionResponse] = {
+            var responses = Array[PredictionResponse]()
+            var start = 0
+            for (req <- requests) {
+                val unbatchStart = System.nanoTime()
+                val resp = PredictionResponse(batchedResults.slice(
+                    start, (start + req.request.requestSize)-1), 
+                    batchedResults.getNumRows, req.statistics)
+                val unbatchingTime = System.nanoTime() - unbatchStart
+                if (req.statistics != null)
+                    req.statistics.unbatchingTime = unbatchingTime
+
+                responses :+= resp
+                start += req.request.requestSize // advance past this request's rows
+            }
+
+            responses
+        }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala b/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala
new file mode 100644
index 0000000..c353e07
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml.serving
+import java.util.concurrent.PriorityBlockingQueue
+import java.util.concurrent.atomic.LongAdder
+
+import org.apache.commons.logging.{Log, LogFactory}
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext
+
+
+case class Batch(size: Int, expectedTime: Long, priority: Double, modelName: String) extends Comparable[Batch] {
+    override def compareTo(that: Batch): Int = {
+        this.priority.compareTo(that.priority)
+    }
+}
+
+class BatchQueue(execType: String, name: String) extends PriorityBlockingQueue[Batch] {
+    val LOG: Log = LogFactory.getLog(classOf[BatchQueue].getName)
+    private val expectedExecutionTime = new LongAdder()
+    private var prevFirstRequest = Map[String, SchedulingRequest]()
+
+    def getName : String = { name }
+
+    def updatePrevRequest(name: String, request: SchedulingRequest) : Unit = {
+        prevFirstRequest += (name -> request)
+    }
+
+    def getPrevRequest(name: String) : SchedulingRequest = { prevFirstRequest.getOrElse(name, null) }
+
+    def enqueue(batch: Batch) : Unit = {
+        LOG.debug("Enqueuing onto: " + getName)
+        synchronized {
+            this.add(batch)
+            expectedExecutionTime.add(batch.expectedTime)
+        }
+    }
+
+    def dequeue() : Batch = {
+        synchronized {
+            // poll inside the lock and null-check: the queue may be drained
+            // between the caller's emptiness check and this poll
+            val nextBatch = this.poll()
+            if (nextBatch == null)
+                return Batch(-1, -1, -1, "NO NAME")
+            expectedExecutionTime.add(-1*nextBatch.expectedTime)
+            return nextBatch
+        }
+    }
+
+    def getExpectedExecutionTime : Long = { expectedExecutionTime.longValue() }
+
+    def getExecType : String = { execType }
+}
+
+class JmlcExecutor(scheduler: Scheduler, execType: String, name: String, gCtx: GPUContext) extends Runnable {
+    @volatile protected var _shouldShutdown: Boolean = false
+    val LOG: Log = LogFactory.getLog(classOf[JmlcExecutor].getName)
+    var prevModel = ""
+
+    def shutdown(): Unit = {
+        _shouldShutdown = true
+    }
+
+    def getExecType: String = { execType }
+
+    def getName: String = { name }
+
+    def run(): Unit = {
+        Thread.sleep(1000) // brief startup delay before polling for work
+        while (!_shouldShutdown) {
+            val requests = scheduler.schedule(this)
+            if (requests.nonEmpty) {
+                val responses = execute(requests)
+                for ((req, resp) <- requests zip responses) {
+                    req.response = resp
+                    req.latch.countDown()
+                }
+            }
+        }
+    }
+
+    def execute(requests: Array[SchedulingRequest]): Array[PredictionResponse] = {
+        var responses = Array[PredictionResponse]()
+        if (requests.nonEmpty) {
+            try {
+                val start = System.nanoTime()
+                val batchedMatrixData = BatchingUtils.batchRequests(requests)
+                val batchingTime = System.nanoTime() - start
+                val req = requests(0)
+                LOG.info("Executing: " + req.model.name + " with batch size: " + batchedMatrixData.getNumRows + " on " + name)
+                val modelAcquireStart = System.nanoTime()
+                val script = scheduler.modelManager.acquire(req.model.name, this)
+                script.setName(this.getName)
+                val modelAcquireTime = System.nanoTime() - modelAcquireStart
+                script.setMatrix(req.model.inputVarName, batchedMatrixData, false)
+                val execStart = System.nanoTime()
+                val res = script.executeScript().getMatrixBlock(req.model.outputVarName)
+                val execTime = System.nanoTime() - execStart
+                responses = BatchingUtils.unbatchRequests(requests, res)
+
+                val modelReleaseStart = System.nanoTime()
+                scheduler.modelManager.release(req.model.name)
+                scheduler.modelManager.releaseMemory(req.memUse)
+                val modelReleaseTime = System.nanoTime() - modelReleaseStart
+                scheduler.onCompleteCallback(req.model.name,
+                                             System.nanoTime() - req.receivedTime,
+                                             requests.length,
+                                             execType, System.nanoTime() - start)
+                if (req.statistics != null)
+                    setStatistics(requests, start, batchingTime, execTime, modelAcquireTime, modelReleaseTime)
+                if (prevModel.nonEmpty)
+                    scheduler.modelManager.unsetModelLocality(prevModel, this)
+                scheduler.modelManager.setModelLocality(req.model.name, this)
+                prevModel = req.model.name
+
+                LOG.info("Done executing request for: " + req.model.name + " on " + name)
+            } catch {
+                case e: Exception =>
+                    // log with the full stack trace (printStackTrace returns Unit and
+                    // cannot be meaningfully concatenated into a message)
+                    LOG.error("An error occurred while executing a batch: " + e.getMessage, e)
+            }
+        }
+        responses
+    }
+
+    def setStatistics(requests: Array[SchedulingRequest],
+                      processingStartTime: Long,
+                      batchingTime: Long,
+                      execTime: Long,
+                      modelAcquireTime: Long,
+                      modelReleaseTime: Long): Unit = {
+        for (req <- requests) {
+            req.statistics.batchingTime = batchingTime
+            req.statistics.execType = getExecType
+            req.statistics.batchSize = requests.length
+            req.statistics.queueWaitTime = processingStartTime - req.receivedTime
+            req.statistics.execTime = execTime
+            req.statistics.modelAcquireTime = modelAcquireTime
+            req.statistics.modelReleaseTime = modelReleaseTime
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala b/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala
new file mode 100644
index 0000000..61fc84f
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
+
+import org.apache.commons.logging.{Log, LogFactory}
+
+import scala.concurrent.Future
+import scala.math.min
+
+object ExecutorQueueManager extends Runnable {
+    val LOG: Log = LogFactory.getLog(ExecutorQueueManager.getClass.getName)
+    var _shutDown = false
+    var _scheduler = LocalityAwareScheduler
+    def shutdown(): Unit = { _shutDown = true }
+
+    override def run() : Unit = {
+        while (!_shutDown) {
+            _scheduler.dummyResponse.synchronized {
+                val schedulableModels = _scheduler.executorTypes.map(
+                    x => _scheduler.getSchedulableModels(x)).reduce(_ union _)
+                if (schedulableModels.nonEmpty) {
+                    for (m <- schedulableModels) {
+                        // every request batch can go to up to three queues
+
+                        // 1. Every batch goes to the global disk queue since the model might get evicted
+                        val diskQueues = _scheduler.executorTypes.map(x => _scheduler.globalDiskQueues.get(x))
+
+                        // 2. If the model is cached in memory, then also put it on the cache queue
+                        var cacheQueues = Array[BatchQueue]()
+                        if (_scheduler.modelManager.isCached(m))
+                            cacheQueues = _scheduler.executorTypes.map(x => _scheduler.globalCacheQueues.get(x))
+
+                        // 3. If the model is local to an executor, then put it on the lowest utilization queue
+                        val localExecutionQueues = getLocalExecutionQueues(m)
+                        val localQueue = if (localExecutionQueues.nonEmpty)
+                            Array[BatchQueue](localExecutionQueues.minBy(x => x.getExpectedExecutionTime))
+                        else Array[BatchQueue]()
+
+                        val queues = diskQueues ++ cacheQueues ++ localQueue
+                        val nextRequest = _scheduler.modelQueues.get(m).peek()
+                        queues.foreach ( queue => {
+                            val qsize = _scheduler.modelQueues.get(m).size()
+                            if (nextRequest ne queue.getPrevRequest(m)) {
+                                val nextBatchSize = min(qsize, _scheduler.getOptimalBatchSize(m, queue.getExecType))
+                                assert(nextBatchSize > 0, "An error occurred - batch size should not be zero")
+                                LOG.debug("Enqueuing: " + nextBatchSize + " for: " + m + " onto: " + queue.getName)
+                                val nextBatch = Batch(
+                                    nextBatchSize, nextBatchSize*_scheduler.getExpectedExecutionTime(m),
+                                    nextRequest.receivedTime - System.nanoTime(), nextRequest.model.name)
+                                queue.enqueue(nextBatch)
+                                LOG.debug("Batch enqueued onto: " + queue.getName)
+                            }
+                            queue.updatePrevRequest(m, nextRequest) } )
+                        }
+                    }
+                }
+            }
+        }
+
+    def getLocalExecutionQueues(model: String) : Array[BatchQueue] = {
+        val execs = _scheduler.modelManager.getModelLocality(model)
+        var queues = Array[BatchQueue]()
+        if (execs == null)
+            return queues
+
+        _scheduler.modelManager.synchronized({
+            // accumulate the queues; the original lookup discarded its result
+            for (ix <- 0 until execs.size()) { queues :+= _scheduler.executorQueues.get(execs.get(ix)) }
+        })
+
+        queues
+    }
+}
+
+object ExecMode extends Enumeration {
+    type MODE = Value
+    val LOCAL, GLOBAL_MEM, GLOBAL_DISK = Value
+}
+
+object LocalityAwareScheduler extends BatchingScheduler {
+    var queueManager : Thread = _
+
+    val globalCacheQueues = new ConcurrentHashMap[String, BatchQueue]()
+    val globalDiskQueues = new ConcurrentHashMap[String, BatchQueue]()
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+
+        executorTypes.foreach ( x => {
+            globalCacheQueues.putIfAbsent(x, new BatchQueue(x, x + "-CACHE"))
+            globalDiskQueues.putIfAbsent(x, new BatchQueue(x, x + "-DISK"))
+        } )
+
+        queueManager = new Thread(ExecutorQueueManager)
+        queueManager.start()
+    }
+
+    override def addModel(model: Model): Unit = {
+        super.addModel(model)
+    }
+
+    override def schedule(executor: JmlcExecutor) : Array[SchedulingRequest] = {
+        var ret = Array[SchedulingRequest]()
+        val localQueue = executorQueues.get(executor)
+        val globalDiskQueue = globalDiskQueues.get(executor.getExecType)
+        val globalMemQueue = globalCacheQueues.get(executor.getExecType)
+        if (localQueue.size() > 0 || globalDiskQueue.size() > 0 || globalMemQueue.size() > 0) {
+            dummyResponse.synchronized {
+                if (localQueue.size() > 0 || globalDiskQueue.size() > 0 || globalMemQueue.size() > 0) {
+                    LOG.debug("Begin scheduling for executor: " + executor.getName)
+                    val execMode = Array[(BatchQueue, ExecMode.MODE)](
+                        (localQueue, ExecMode.LOCAL),
+                        (globalDiskQueue, ExecMode.GLOBAL_DISK),
+                        (globalMemQueue, ExecMode.GLOBAL_MEM)
+                    ).filter(x => x._1.size() > 0).maxBy(x => x._1.getExpectedExecutionTime)._2
+
+                    val batch = execMode match {
+                        case ExecMode.LOCAL => localQueue.peek()
+                        case ExecMode.GLOBAL_MEM => globalMemQueue.peek()
+                        case ExecMode.GLOBAL_DISK => globalDiskQueue.peek()
+                    }
+                    assert(batch != null, "Something is wrong. Batch should not be null!")
+
+                    // now we need to ask the resource manager if there's enough memory to execute the batch
+                    val model = modelManager.get(batch.modelName)
+
+                    // If there's enough memory we can actually remove the requests from the queue and
+                    // submit them for processing
+                    val mqueue = modelQueues.get(batch.modelName)
+                    val numToDequeue = min(batch.size, mqueue.size())
+
+                    // if this value is zero there are no more requests and the batch is stale
+                    if (numToDequeue == 0) {
+                        execMode match {
+                            case ExecMode.LOCAL => localQueue.poll()
+                            case ExecMode.GLOBAL_DISK => globalDiskQueue.poll()
+                            case ExecMode.GLOBAL_MEM => globalMemQueue.poll()
+                        }
+                    } else {
+                        val memReceived = modelManager.tryAllocMem(model.name, batch.size)
+                        if (memReceived < 0) {
+                            return ret
+                        }
+
+                        // now we need to actually remove the request from the queue since it's going to be processed
+                        execMode match {
+                            case ExecMode.LOCAL => localQueue.poll()
+                            case ExecMode.GLOBAL_DISK => globalDiskQueue.poll()
+                            case ExecMode.GLOBAL_MEM => globalMemQueue.poll()
+                        }
+
+                        // now we can actually take the original requests out of the model queues
+                        LOG.debug("Scheduling: " + numToDequeue + " for " + batch.modelName + " on " + executor.getName)
+                        for (_ <- 0 until numToDequeue) {
+                            val nextRequest = mqueue.poll()
+                            assert(nextRequest != null, "Something is wrong - request should not be null!")
+
+                            nextRequest.memUse = memReceived
+                            // statistics is null when statistics collection is disabled
+                            if (nextRequest.statistics != null)
+                                nextRequest.statistics.execMode = execMode match {
+                                    case ExecMode.LOCAL => 0
+                                    case ExecMode.GLOBAL_MEM => 1
+                                    case ExecMode.GLOBAL_DISK => 2
+                                    case _ => -1
+                                }
+                            ret :+= nextRequest
+                        }
+                        LOG.debug("Done scheduling on: " + executor.getName)
+                    }
+                }
+            }
+        }
+        ret
+    }
+
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine which
+      * models to execute next
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which a prediction is desired
+      * @return A Future that completes with the PredictionResponse
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+
+        if (_statistics) {
+            statistics.queueSize = modelQueues.get(model.name).size
+            statistics.preprocWaitTime = System.nanoTime() - request.receivedTime
+        }
+
+        modelQueues.get(model.name).add(schedulingRequest)
+
+        // CountDownLatch.await returns false on timeout rather than throwing,
+        // so fall back to the dummy response explicitly on timeout
+        if (schedulingRequest.latch.await(timeout.length, timeout.unit))
+            schedulingRequest.response
+        else
+            dummyResponse
+    }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala b/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala
new file mode 100644
index 0000000..b67d367
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.LongAdder
+
+import org.apache.commons.logging.{Log, LogFactory}
+import org.apache.sysml.api.jmlc.{Connection, PreparedScript}
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.utils.PersistentLRUCache
+
+trait ModelManager {
+    val LOG: Log = LogFactory.getLog(classOf[ModelManager].getName)
+    var modelLocality = new ConcurrentHashMap[String, util.ArrayList[JmlcExecutor]]()
+    val conn: Connection = new Connection()
+    val availableMemory = new LongAdder
+    var totalMemory = 0L
+    var cleanupEnabled = true
+    var memCheckEnabled = true
+    var models: Map[String, Model] = Map()
+
+    def setAvailableMemory(memBytes: Long) : Unit = {
+        LOG.info("Setting total memory to: " + memBytes + " bytes")
+        totalMemory = memBytes
+        availableMemory.reset()
+        availableMemory.add(memBytes)
+    }
+
+    def getAvailableMemory : Long = { availableMemory.longValue() }
+
+    def acquireMemory(bytes: Long) : Long = {
+        // if memory checking is not enabled just always say they get the memory
+        if (!memCheckEnabled || bytes == 0)
+            return bytes
+        LOG.debug("Requested: " + bytes)
+
+        // otherwise check to see if there is enough memory to meet the request
+        if (bytes <= availableMemory.longValue()) {
+            availableMemory.add(-1 * bytes)
+            LOG.debug("Granted: " + bytes + "/" + availableMemory.longValue())
+            return bytes
+        }
+        // not enough memory available :(
+
+        LOG.debug("Insufficient memory. Request was not granted")
+        -1
+    }
+
+    def releaseMemory(bytes: Long) : Unit = {
+        if (bytes > 0) {
+            LOG.debug("Releasing: " + bytes)
+            availableMemory.add(bytes)
+            LOG.debug("Available memory is now: " + availableMemory.longValue())
+        }
+    }
+
+    def setModelLocality(model: String, exec: JmlcExecutor) : Unit = {
+        this.synchronized({
+            modelLocality.putIfAbsent(model, new util.ArrayList[JmlcExecutor]())
+            modelLocality.get(model).add(exec)
+        })
+    }
+
+    def unsetModelLocality(model: String, exec: JmlcExecutor) : Unit = {
+        this.synchronized({ modelLocality.get(model).remove(exec) })
+    }
+
+    def getModelLocality(model: String) : util.ArrayList[JmlcExecutor] = { modelLocality.get(model) }
+
+    def isModelLocal(model: String, exec: JmlcExecutor) : Boolean = { getModelLocality(model).contains(exec) }
+
+    def disableCleanup() : Unit = { cleanupEnabled = false }
+
+    def disableMemcheck() : Unit = { memCheckEnabled = false }
+
+    def put(model: Model): Unit
+
+    def get(name: String): Model
+
+    def putWeight(name: String, weight: MatrixBlock) : Unit
+
+    def acquire(name: String, executor: JmlcExecutor) : PreparedScript
+
+    def release(name: String) : Unit
+}
+
+object ReferenceCountedModelManager extends ModelManager {
+    var modelRefCounts: Map[String,LongAdder] = Map()
+    var weightCache : PersistentLRUCache = _
+
+    override def setAvailableMemory(maxBytes: Long) : Unit = {
+        super.setAvailableMemory(maxBytes)
+        weightCache = new PersistentLRUCache((0.80*maxBytes).toLong)
+        weightCache.enableReadOnlyMode(true)
+    }
+
+    def tryAllocMem(name: String, batchSize: Int) : Long = {
+        // TODO: More sophisticated memory management
+        val extraMem = (0.5*models(name).weightMem).toLong
+        val weightMem = if (modelRefCounts(name).longValue() > 0) 0L else models(name).weightMem
+        val memReceived = acquireMemory(extraMem + weightMem)
+        if (memReceived < 0) memReceived else extraMem
+    }
+
+    def isCached(name: String) : Boolean = { modelRefCounts(name).longValue() > 0 }
+
+    def acquire(name: String, executor: JmlcExecutor) : PreparedScript = {
+        LOG.debug("Acquiring model: " + name + " Ref count: " + modelRefCounts(name).longValue())
+
+        val execName = if (executor.getExecType == "GPU") executor.getName else executor.getExecType
+        val ps = models(name).script(execName)
+        if (modelRefCounts(name).longValue() > 0 && ps.hasPinnedData) {
+            modelRefCounts(name).increment()
+            return ps.clone(false)
+        }
+
+        // otherwise we need to re-pin the weights, possibly reading them from disk
+        val model = models(name)
+        model.synchronized {
+            LOG.debug("Pinning weights for: " + name)
+            model.weightFiles.foreach(x => ps.setMatrix(x._1, weightCache.getAsMatrixBlock(x._2), true))
+            modelRefCounts(name).increment()
+        }
+        LOG.debug("Done acquiring model: " + name)
+        ps.clone(false)
+    }
+
+    override def disableCleanup(): Unit = {
+        super.disableCleanup()
+        LOG.debug("Cleanup is disabled")
+    }
+
+    def release(name: String) : Unit = {
+        modelRefCounts(name).decrement()
+        releaseMemory(models(name).weightMem)
+
+        LOG.debug("Releasing model: " + name + " Ref count: " + modelRefCounts(name).longValue())
+        if (modelRefCounts(name).longValue() == 0) {
+            models(name).script.synchronized {
+                if (modelRefCounts(name).longValue() == 0) {
+                    models(name).script.foreach { x => x._2.clearPinnedData() }
+                }
+            }
+        }
+    }
+
+    def put(model: Model) : Unit = {
+        models += (model.name -> model)
+        modelRefCounts += (model.name -> new LongAdder())
+    }
+
+    def putWeight(name: String, weight: MatrixBlock) : Unit = {
+        weightCache.put(name, weight)
+    }
+
+    def get(name: String) : Model = { models(name) }
+
+}
\ No newline at end of file
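
A distilled sketch of the memory-accounting contract above (acquireMemory returns the
granted number of bytes, or -1 when the request cannot be met; sizes are made up):
```
// Illustration of the ModelManager memory contract (made-up sizes).
import org.apache.sysml.api.ml.serving.ReferenceCountedModelManager

val mgr = ReferenceCountedModelManager
mgr.setAvailableMemory(1L << 30)            // 1 GB budget (also sizes the weight cache)
val granted = mgr.acquireMemory(256L << 20) // request 256 MB
if (granted >= 0)
  try { /* execute the batch */ } finally mgr.releaseMemory(granted)
```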
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala b/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala
new file mode 100644
index 0000000..44ff26f
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.atomic.LongAdder
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+object NonBatchingScheduler extends Scheduler {
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Non Batching Scheduler with: ${numCores} CPUs and ${gpus} GPUs")
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+    }
+
+    override def schedule(executor: JmlcExecutor): Array[SchedulingRequest] = {
+        var ret = Array[SchedulingRequest]()
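+        // dummyResponse doubles as the lock object guarding the poll of the shared request queue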
+        dummyResponse.synchronized {
+            if (requestQueue.size() > 0) {
+                val request = requestQueue.poll()
+                ret :+= request
+            }
+        }
+        ret
+    }
+
+    var requestNum = new LongAdder
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine which
+      * models to execute next
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which prediction is desired
+      * @return A Future which will eventually contain the PredictionResponse for this request
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+        if (_statistics) statistics.queueSize = requestQueue.size()
+        requestQueue.add(schedulingRequest)
+        counter += 1
+        try {
+            schedulingRequest.latch.await(timeout.length, timeout.unit)
+            schedulingRequest.response
+        } catch {
+            case e : scala.concurrent.TimeoutException => dummyResponse
+        }
+    }
+
+    override def onCompleteCallback(model: String, latency: Double, batchSize: Int, execType: String, execTime: Long): Unit = {}
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala b/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala
new file mode 100644
index 0000000..f8c2345
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.io.File
+
+import akka.http.scaladsl.server.StandardRoute
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.Http
+import akka.actor.ActorSystem
+import akka.stream.ActorMaterializer
+import org.apache.commons.cli.PosixParser
+import com.typesafe.config.ConfigFactory
+
+import scala.concurrent.duration._
+import java.util.HashMap
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import spray.json._
+import java.util.concurrent.atomic.LongAdder
+
+import scala.concurrent.{Await, Future}
+import scala.math.{max, pow}
+import org.apache.sysml.runtime.matrix.data.{MatrixBlock, OutputInfo}
+import org.apache.sysml.parser.DataExpression
+import org.apache.sysml.runtime.io.IOUtilFunctions
+import org.apache.sysml.api.jmlc.Connection
+import org.apache.sysml.api.jmlc.PreparedScript
+import org.apache.sysml.conf.ConfigurationManager
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.util.DataConverter
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+
+import scala.concurrent.ExecutionContext
+
+// format: can be file, binary, csv, ijv, jpeg, ...
+
+case class RequestStatistics(var batchSize: Int = -1,
+                             var execTime: Long = -1,
+                             var execType: String = "",
+                             var requestDeserializationTime: Long = -1,
+                             var responseSerializationTime: Long = -1,
+                             var modelAcquireTime: Long = -1,
+                             var modelReleaseTime: Long = -1,
+                             var batchingTime: Long = -1,
+                             var unbatchingTime: Long = -1,
+                             var queueWaitTime: Long = -1,
+                             var queueSize: Int = -1,
+                             var execMode: Int = 0,
+                             var preprocWaitTime: Long = -1)
+case class PredictionRequestExternal(name: String, data: Array[Double], rows: Int, cols: Int)
+case class PredictionResponseExternal(response: Array[Double], rows: Int, cols: Int, statistics: RequestStatistics)
+
+case class AddModelRequest(name: String, dml: String, inputVarName: String,
+                           outputVarName: String, weightsDir: String,
+                           latencyObjective: String, batchSize: Array[Int], memUse: Array[Long])
+
+case class Model(name: String,
+                 script: Map[String,PreparedScript],
+                 inputVarName: String,
+                 outputVarName: String,
+                 latencyObjective: Duration,
+                 weightFiles: Map[String, String],
+                 coeffs: (Double, Double),
+                 weightMem: Long)
+case class PredictionRequest(data : MatrixBlock, modelName : String, requestSize : Int, receivedTime : Long)
+case class PredictionResponse(response: MatrixBlock, batchSize: Int, statistics: RequestStatistics)
+case class MatrixBlockContainer(numRows: Long, numCols: Long, nnz: Long, sum: Double, data: MatrixBlock)
+
+trait PredictionJsonProtocol extends SprayJsonSupport with DefaultJsonProtocol {
+    implicit val RequestStatisticsFormat = jsonFormat13(RequestStatistics)
+    implicit val predictionRequestExternalFormat = jsonFormat4(PredictionRequestExternal)
+    implicit val predictionResponseExternalFormat = jsonFormat4(PredictionResponseExternal)
+}
+
+trait AddModelJsonProtocol extends SprayJsonSupport with DefaultJsonProtocol {
+    implicit val AddModelRequestFormat = jsonFormat8(AddModelRequest)
+}
+
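+// Empty placeholder class: the companion object below uses classOf[PredictionService] to derive its logger name.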
+class PredictionService {
+
+}
+
+/*
+Usage:
+1. Compiling a fat jar with the maven assembly plugin into our standalone jar created a lot of issues.
+Hence, for the time being, we recommend downloading the jars using the script below:
+SCALA_VERSION="2.11"
+AKKA_HTTP_VERSION="10.1.3"
+AKKA_VERSION="2.5.14"
+PREFIX="http://central.maven.org/maven2/com/typesafe/akka/"
+JARS=""
+for PKG in actor stream protobuf
+do
+  PKG_NAME="akka-"$PKG"_"$SCALA_VERSION
+  JAR_FILE=$PKG_NAME"-"$AKKA_VERSION".jar"
+  wget $PREFIX$PKG_NAME"/"$AKKA_VERSION"/"$JAR_FILE
+  JARS=$JARS$JAR_FILE":"
+done
+for PKG in http http-core parsing
+do
+  PKG_NAME="akka-"$PKG"_"$SCALA_VERSION
+  JAR_FILE=$PKG_NAME"-"$AKKA_HTTP_VERSION".jar"
+  wget $PREFIX$PKG_NAME"/"$AKKA_HTTP_VERSION"/"$JAR_FILE
+  JARS=$JARS$JAR_FILE":"
+done
+wget http://central.maven.org/maven2/com/typesafe/config/1.3.3/config-1.3.3.jar
+wget http://central.maven.org/maven2/com/typesafe/ssl-config-core_2.11/0.2.4/ssl-config-core_2.11-0.2.4.jar
+wget http://central.maven.org/maven2/org/reactivestreams/reactive-streams/1.0.2/reactive-streams-1.0.2.jar
+wget http://central.maven.org/maven2/org/scala-lang/scala-library/2.11.12/scala-library-2.11.12.jar
+wget http://central.maven.org/maven2/org/scala-lang/scala-parser-combinators/2.11.0-M4/scala-parser-combinators-2.11.0-M4.jar
+wget http://central.maven.org/maven2/commons-cli/commons-cli/1.4/commons-cli-1.4.jar
+wget http://central.maven.org/maven2/com/typesafe/akka/akka-http-spray-json-experimental_2.11/2.4.11.2/akka-http-spray-json-experimental_2.11-2.4.11.2.jar
+wget http://central.maven.org/maven2/io/spray/spray-json_2.11/1.3.2/spray-json_2.11-1.3.2.jar
+JARS=$JARS"config-1.3.3.jar:ssl-config-core_2.11-0.2.4.jar:reactive-streams-1.0.2.jar:commons-cli-1.4.jar:scala-parser-combinators-2.11.0-M4.jar:scala-library-2.11.12.jar:akka-http-spray-json-experimental_2.11-2.4.11.2.jar:spray-json_2.11-1.3.2.jar"
+echo "Include the following jars into the classpath: "$JARS
+
+
+2. Copy SystemML.jar and systemml-1.2.0-SNAPSHOT-extra.jar into the directory where akka jars are placed
+
+3. Start the server:
+java -cp $JARS org.apache.sysml.api.ml.serving.PredictionService -port 9000 -admin_password admin
+
+4. Check the health of the server:
+curl -XGET localhost:9000/health   # /health is unsecured, so no credentials are needed
+
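+5. Register a model (an illustrative example; the JSON fields mirror the AddModelRequest
+case class - name, dml, inputVarName, outputVarName, weightsDir, latencyObjective,
+batchSize, memUse - and all values below are placeholders):
+curl -u admin:admin -XPOST -H "Content-Type:application/json" -d '{ "name":"test", "dml":"<dml script as a string>", "inputVarName":"X", "outputVarName":"Y", "weightsDir":"/path/to/weights", "latencyObjective":"100 milliseconds", "batchSize":[1,10,100], "memUse":[1000000,2000000,4000000] }' localhost:9000/register-model
+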
+6. Perform prediction (the request body must match the PredictionRequestExternal case class: name, data, rows, cols):
+curl -XPOST -H "Content-Type:application/json" -d '{ "name":"test", "data":[1.0,2.0,3.0], "rows":1, "cols":3 }' localhost:9000/predict
+
+7. Shutdown the server:
+curl -u admin -XGET localhost:9000/shutdown
+
+ */
+
+object PredictionService extends PredictionJsonProtocol with AddModelJsonProtocol {
+    val __DEBUG__ = false
+
+    val LOG = LogFactory.getLog(classOf[PredictionService].getName)
+    val customConf = ConfigFactory.parseString("""
+        akka.http.server.idle-timeout=infinite
+        akka.http.client.idle-timeout=infinite
+        akka.http.host-connection-pool.idle-timeout=infinite
+        akka.http.host-connection-pool.client.idle-timeout=infinite
+        akka.http.server.max-connections=100000
+    """)
+    val basicConf = ConfigFactory.load()
+    val combined = customConf.withFallback(basicConf)
+    implicit val system = ActorSystem("systemml-prediction-service", ConfigFactory.load(combined))
+    implicit val materializer = ActorMaterializer()
+    implicit val executionContext = ExecutionContext.global
+    implicit val timeout = akka.util.Timeout(300.seconds)
+    val userPassword = new HashMap[String, String]()
+    var bindingFuture: Future[Http.ServerBinding] = null
+    var scheduler: Scheduler = null
+    val conn = new Connection()
+    var existingMatrixBlocks = Array[MatrixBlockContainer]()
+
+    def getCommandLineOptions(): org.apache.commons.cli.Options = {
+        val hostOption = new org.apache.commons.cli.Option("ip", true, "IP address")
+        val portOption = new org.apache.commons.cli.Option("port", true, "Port number")
+        val numRequestOption = new org.apache.commons.cli.Option("max_requests", true, "Maximum number of requests")
+        val timeoutOption = new org.apache.commons.cli.Option("timeout", true, "Timeout in milliseconds")
+        val passwdOption = new org.apache.commons.cli.Option("admin_password", true, "Admin password. Default: admin")
+        val helpOption = new org.apache.commons.cli.Option("help", false, "Show usage message")
+        val maxSizeOption = new org.apache.commons.cli.Option("max_bytes", true, "Maximum size of request in bytes")
+        val statisticsOption = new org.apache.commons.cli.Option("statistics", true, "Gather statistics on request execution")
+        val numCpuOption = new org.apache.commons.cli.Option("num_cpus", true, "How many CPUs should be allocated to the prediction service. Default nproc-1")
+        val gpusOption = new org.apache.commons.cli.Option("gpus", true, "GPUs available to this process. Default: 0")
+        val schedulerOption = new org.apache.commons.cli.Option("scheduler", true, "Scheduler implementation to use. Default: locality-aware")
+
+        // Only the port option is required
+        portOption.setRequired(true)
+
+        return new org.apache.commons.cli.Options()
+          .addOption(hostOption).addOption(portOption).addOption(numRequestOption)
+          .addOption(passwdOption).addOption(timeoutOption).addOption(helpOption)
+          .addOption(maxSizeOption).addOption(statisticsOption).addOption(numCpuOption)
+          .addOption(gpusOption).addOption(schedulerOption)
+    }
+
+    def main(args: Array[String]): Unit = {
+        // Parse command line arguments:
+        val options = getCommandLineOptions()
+        val line = new PosixParser().parse(options, args)
+        if (line.hasOption("help")) {
+            new org.apache.commons.cli.HelpFormatter().printHelp("systemml-prediction-service", options)
+            return
+        }
+        userPassword.put("admin", line.getOptionValue("admin_password", "admin"))
+        val currNumRequests = new LongAdder
+        val maxNumRequests = if (line.hasOption("max_requests"))
+            line.getOptionValue("max_requests").toLong else Long.MaxValue
+        val timeout = if (line.hasOption("timeout"))
+            Duration(line.getOptionValue("timeout").toLong, MILLISECONDS) else 300.seconds
+        val sizeDirective = if (line.hasOption("max_bytes"))
+            withSizeLimit(line.getOptionValue("max_bytes").toLong) else withoutSizeLimit
+        val numCores = if (line.hasOption("num_cpus"))
+            line.getOptionValue("num_cpus").toInt else Runtime.getRuntime.availableProcessors() - 1
+        val gpus = if (line.hasOption("gpus")) line.getOptionValue("gpus") else null
+        val schedulerType = line.getOptionValue("scheduler", "locality-aware")
+
+        // Initialize statistics counters
+        val numTimeouts = new LongAdder
+        val numFailures = new LongAdder
+        val totalTime = new LongAdder
+        val numCompletedPredictions = new LongAdder
+
+        // For now the models need to be loaded every time. TODO: pass the location of the serialized models via the command line
+        var models = Map[String, Model]()
+
+        // Instantiate the scheduler via the factory
+        scheduler = SchedulerFactory.getScheduler(schedulerType)
+        val maxMemory = Runtime.getRuntime.maxMemory()  // total memory is just what the JVM has currently allocated
+
+        LOG.info("Total memory allocated to server: " + maxMemory)
+        scheduler.start(numCores, maxMemory, gpus)
+
+        // Define unsecured routes: /predict and /health
+        val unsecuredRoutes = {
+            path("predict") {
+                withoutRequestTimeout {
+                    post {
+                        validate(currNumRequests.longValue() < maxNumRequests, "The prediction server received too many requests. Ignoring the current request.") {
+                            entity(as[PredictionRequestExternal]) { request =>
+                                validate(models.contains(request.name), "The model is not available.") {
+                                    try {
+                                        currNumRequests.increment()
+                                        val start = System.nanoTime()
+                                        val processedRequest = processPredictionRequest(request)
+                                        val deserializationTime = System.nanoTime() - start
+
+                                        val response = Await.result(
+                                            scheduler.enqueue(processedRequest, models(request.name)), timeout)
+                                        totalTime.add(System.nanoTime() - start)
+
+                                        numCompletedPredictions.increment()
+                                        complete(StatusCodes.OK, processPredictionResponse(response, "NOT IMPLEMENTED", deserializationTime))
+                                    } catch {
+                                        case e: scala.concurrent.TimeoutException => {
+                                            numTimeouts.increment()
+                                            complete(StatusCodes.RequestTimeout, "Timeout occurred")
+                                        }
+                                        case e: Exception => {
+                                            numFailures.increment()
+                                            e.printStackTrace()
+                                            val msg = "Exception occurred while executing the prediction request: "
+                                            complete(StatusCodes.InternalServerError, msg + e.getMessage)
+                                        }
+                                    } finally {
+                                        currNumRequests.decrement()
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            } ~ path("health") {
+                get {
+                    val stats = "Number of requests (total/completed/timeout/failures): " +
+                      currNumRequests.longValue() + "/" + numCompletedPredictions.longValue() + "/" +
+                      numTimeouts.longValue() + "/" + numFailures.longValue() + ".\n" +
+                      "Average prediction time: " + ((totalTime.doubleValue() * 1e-6) / numCompletedPredictions.longValue()) + " ms.\n"
+                    complete(StatusCodes.OK, stats)
+                }
+            }
+        }
+
+        // For administration: This can be later extended for supporting multiple users.
+        val securedRoutes = {
+            authenticateBasicAsync(realm = "secure site", userAuthenticate) {
+                user =>
+                    path("shutdown") {
+                        get {
+                            shutdownService(user, scheduler)
+                        }
+                    } ~
+                      path("register-model") {
+                          withoutRequestTimeout {
+                              post {
+                                  entity(as[AddModelRequest]) { request =>
+                                      validate(!models.contains(request.name), "The model is already loaded") {
+                                          try {
+                                              val weightsInfo = processWeights(request.weightsDir)
+                                              val inputs = weightsInfo._1.keys.toArray ++ Array[String](request.inputVarName)
+
+                                              // compile for executor types
+                                              val scriptCpu = conn.prepareScript(
+                                                  request.dml, inputs, Array[String](request.outputVarName))
+                                              var scripts = Map("CPU" -> scriptCpu)
+
+                                              if (gpus != null) {
+                                                  GPUContextPool.AVAILABLE_GPUS = gpus
+                                                  for (ix <- 0 until GPUContextPool.getAvailableCount) {
+                                                      LOG.info("Compiling script for GPU: " + ix)
+                                                      scripts += (s"GPU${ix}" -> conn.prepareScript(
+                                                          request.dml, inputs, Array[String](request.outputVarName),
+                                                          true, true, ix))
+                                                  }
+                                              }
+
+                                              // b = cov(x,y) / var(x)
+                                              // a = mean(y) - b*mean(x)
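+                                              // i.e., an ordinary least squares fit of memUse (y) on batchSize (x) - a simple linear model of per-batch memory use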
+                                              val n = max(request.batchSize.length, 1).toDouble
+                                              val x = request.batchSize
+                                              val y = request.memUse
+                                              val mux = x.sum / n
+                                              val muy = y.sum / n
+                                              val vx = (1 / n) * x.map(v => pow(v - mux, 2.0)).sum
+                                              val b = ((1 / n) * (x.map(v => v - mux) zip y.map(v => v - muy)
+                                                ).map(v => v._1 * v._2).sum) * (1 / vx)
+                                              val a = muy - b * mux
+
+                                              // now register the created model
+                                              val model = Model(request.name,
+                                                  scripts,
+                                                  request.inputVarName,
+                                                  request.outputVarName,
+                                                  Duration(request.latencyObjective),
+                                                  weightsInfo._1, (a, b), weightsInfo._2)
+                                              models += (request.name -> model)
+                                              scheduler.addModel(model)
+                                              complete(StatusCodes.OK)
+                                          } catch {
+                                              case e: Exception => {
+                                                  numFailures.increment()
+                                                  e.printStackTrace()
+                                                  complete(StatusCodes.InternalServerError,
+                                                      "Exception occurred while trying to add model: " + e.getMessage)
+                                              }
+                                          }
+                                      }
+                                  }
+                              }
+                          }
+                      }
+            }
+        }
+
+        bindingFuture = Http().bindAndHandle(
+            sizeDirective { // Both secured and unsecured routes need to respect the size restriction
+                unsecuredRoutes ~ securedRoutes
+            },
+            line.getOptionValue("ip", "localhost"), line.getOptionValue("port").toInt)
+
+        println("Prediction Server online.")
+        // Block forever; the service is stopped via the /shutdown endpoint, which exits the JVM
+        // (so the unbind below is effectively unreachable)
+        while (true) Thread.sleep(100)
+        bindingFuture
+          .flatMap(_.unbind())
+          .onComplete(_ ⇒ system.terminate())
+    }
+
+    def processPredictionResponse(response : PredictionResponse, 
+                                  format : String, 
+                                  deserializationTime: Long) : PredictionResponseExternal = {
+        if (response != null) {
+            val start = System.nanoTime()
+            val dataArray = response.response.getDenseBlockValues
+            val rows = response.response.getNumRows
+            val cols = response.response.getNumColumns
+            val serializationTime = System.nanoTime() - start
+            if (response.statistics != null) {
+                response.statistics.requestDeserializationTime = deserializationTime
+                response.statistics.responseSerializationTime = serializationTime
+            }
+            PredictionResponseExternal(dataArray, rows, cols, response.statistics)
+        } else {
+            PredictionResponseExternal(null, -1, -1, null)
+        }
+    }
+
+    def processWeights(dirname: String) : (Map[String, String], Long) = {
+        val dir = new File(dirname)
+        if (!(dir.exists && dir.isDirectory))
+            throw new Exception("Weight directory: " + dirname + " is invalid")
+
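+        // Skip the "binary" conversion subdirectory, .mtd metadata files, and already-converted _bin.mtx files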
+        val weightsWithSize = dir.listFiles().filter(
+            x => !(x.isDirectory && (x.toString contains "binary"))).map(_.toString).filter(
+            x => (x.slice(x.length-3, x.length) != "mtd") &&
+            !(x contains "_bin.mtx")).
+          map(x => getNameFromPath(x) -> registerWeight(x, dirname)).toMap
+
+        val weightMap = weightsWithSize.map(x => x._1 -> x._2._1)
+        val totalSize = weightsWithSize.map(x => x._2._2).sum
+
+        (weightMap, totalSize)
+    }
+
+    def getNameFromPath(path: String) : String = {
+        path.split("/").last.split("\\.")(0)
+    }
+
+    def registerWeight(path: String, dir: String) : (String, Long) = {
+        val res = convertToBinaryIfNecessary(path, dir)
+        scheduler.modelManager.putWeight(res._2, res._1)
+        (res._2, res._1.getInMemorySize)
+    }
+
+    def convertToBinaryIfNecessary(path: String, dir: String) : (MatrixBlock, String) = {
+        var pathActual = path
+        LOG.info("Reading weight: " + path)
+        val data = conn.readMatrix(path)
+
+        if (!isBinaryFormat(path)) {
+            LOG.info("Converting weight to binary format")
+            data.getMatrixCharacteristics
+            val binPath = dir + "/binary/" + getNameFromPath(path) + ".mtx"
+            DataConverter.writeMatrixToHDFS(data, binPath,
+                OutputInfo.BinaryBlockOutputInfo,
+                new MatrixCharacteristics(data.getNumRows, data.getNumColumns, ConfigurationManager.getBlocksize,
+                    ConfigurationManager.getBlocksize, data.getNonZeros))
+            pathActual = binPath
+        }
+        (data, pathActual)
+    }
+
+    def isBinaryFormat(path: String) : Boolean = {
+        val mtdName = DataExpression.getMTDFileName(path)
+        val mtd = new DataExpression().readMetadataFile(mtdName, false)
+        if (mtd.containsKey("format")) mtd.getString("format") == "binary" else false
+    }
+
+    def processPredictionRequest(request : PredictionRequestExternal) : PredictionRequest = {
+        val mat = new MatrixBlock(request.rows, request.cols, false)
+        mat.init(request.data, request.rows, request.cols)
+        PredictionRequest(mat, request.name, request.rows, System.nanoTime())
+    }
+
+    def processMatrixInput(data : String, rows : Int, cols : Int, format : String) : MatrixBlock = {
+        val result = format match {
+            case "csv" => processTextInput(data, rows, cols, DataExpression.FORMAT_TYPE_VALUE_CSV)
+            case _ => throw new Exception("Only CSV Input currently supported")
+        }
+        result
+    }
+
+    def processTextInput(data : String, rows : Int, cols : Int, format : String) : MatrixBlock = {
+        val is = IOUtilFunctions.toInputStream(data)
+        conn.convertToMatrix(is, rows, cols, format)
+    }
+
+    def userAuthenticate(credentials: akka.http.scaladsl.server.directives.Credentials): Future[Option[String]] = {
+        credentials match {
+            case p@akka.http.scaladsl.server.directives.Credentials.Provided(id) =>
+                Future {
+                    if (userPassword.containsKey(id) && p.verify(userPassword.get(id))) Some(id)
+                    else None
+                }
+            case _ => Future.successful(None)
+        }
+    }
+
+    def shutdownService(user: String, scheduler: Scheduler): StandardRoute = {
+        if (user.equals("admin")) {
+            try {
+                Http().shutdownAllConnectionPools() andThen { case _ => bindingFuture.flatMap(_.unbind()).onComplete(_ ⇒ system.terminate()) }
+                scheduler.shutdown()
+                complete(StatusCodes.OK, "Shutting down the server.")
+            } finally {
+                new Thread(new Runnable {
+                    def run() {
+                        Thread.sleep(100) // wait 100ms so the reply can be sent, then kill the prediction JVM so that we don't block on scala.io.StdIn.readLine()
+                        System.exit(0)
+                    }
+                }).start()
+            }
+        }
+        else {
+            complete(StatusCodes.BadRequest, "Only admin can shutdown the service.")
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala b/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala
new file mode 100644
index 0000000..03bc8cd
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import breeze.linalg._
+import breeze.numerics._
+import breeze.stats._
+
+class RLSEstimator {
+    val dataQueue = new LinkedBlockingQueue[(Double, Double)]()
+    val chunkSize = 2
+
+    var isInitialized = false
+    var isFinalized = false
+    var Q : DenseMatrix[Double] = _
+    var b : DenseMatrix[Double] = _
+    var n = 0
+    val lda = 0.98
+    val eps = 0.00000001
+    var sigma = -1.0
+
+    def enqueueExample(batchSize: Int, latency: Double) : Unit = {
+        if (!isFinalized) {
+            println("ENQUEUING => " + dataQueue.size())
+            dataQueue.add((batchSize.toDouble, latency))
+            if (dataQueue.size() >= chunkSize)
+                update()
+        }
+    }
+
+    def dequeueExamples() : (DenseMatrix[Double], DenseMatrix[Double]) = {
+        val X = DenseMatrix.zeros[Double](chunkSize,4)
+        val y = DenseMatrix.zeros[Double](chunkSize, 1)
+
+        for (ix <- 0 until chunkSize) {
+            val (x_ex, y_ex) = dataQueue.poll()
+            X(ix,::) := DenseVector[Double](1.0, x_ex, pow(x_ex,2), pow(x_ex,3)).t
+            y(ix,0) = y_ex
+        }
+        (X, y)
+    }
+
+    def update() : Unit = {
+        val s = pow(lda, n)
+        val R = dequeueExamples()
+        val X = R._1
+        val y = R._2
+        if (!isInitialized) {
+            Q = X.t * X
+            b = Q \ (X.t * y)
+            isInitialized = true
+        } else if (s >= eps) {
+            val Q_new = Q + (X.t * X)
+            val S = pow(lda, n) * DenseMatrix.eye[Double](chunkSize)
+            val K = inv(Q_new) * (X.t * S) // Kalman filter gain
+            val V = y - (X * b) // Innovations
+            b :+= K * V
+            Q = Q_new
+        } else {
+            isFinalized = true
+            dataQueue.clear()
+        }
+        sigma = variance(y - (X*b))
+        n += 1
+    }
+
+    def predict(batchSize: Int) : (Double,Double) = {
+        val x = DenseMatrix(1.0, batchSize, pow(batchSize,2), pow(batchSize,3)).reshape(1,4)
+        val y_hat = x*b
+        (max(y_hat(0,0), 0.0), sigma)
+    }
+
+}
\ No newline at end of file
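A minimal sketch of the intended call pattern for the estimator (illustrative only; the
observation values are invented):
```
val est = new RLSEstimator
// Observations are (batchSize, latency) pairs; every chunkSize (= 2) enqueued
// examples trigger one recursive update of the cubic latency model b.
est.enqueueExample(4, 10.0)
est.enqueueExample(8, 19.5) // second example: update() runs here
val (expectedLatency, variance) = est.predict(16)
```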
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala b/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala
new file mode 100644
index 0000000..39e85c0
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import java.util.concurrent._
+import java.util.List
+
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+
+import scala.concurrent.ExecutionContext
+
+case class SchedulingRequest(request: PredictionRequest,
+                             model: Model,
+                             latch: CountDownLatch,
+                             receivedTime: Long,
+                             var response: PredictionResponse = null,
+                             statistics: RequestStatistics = null,
+                             var memUse: Long = 0)
+
+trait Scheduler {
+    val LOG: Log = LogFactory.getLog(classOf[Scheduler].getName)
+    var executorService: ExecutorService = _
+    protected var _statistics = true
+    implicit val ec = ExecutionContext.global
+    var executorTypes = Array[String]()
+    var modelManager = ReferenceCountedModelManager
+
+    def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Scheduler with ${numCores} CPUs and ${gpus} GPUs")
+        var numGpus = 0
+        var gCtxs: List[GPUContext] = null
+        if (gpus != null) {
+            GPUContextPool.AVAILABLE_GPUS = gpus
+            gCtxs = GPUContextPool.getAllGPUContexts
+            numGpus = gCtxs.size
+        }
+
+        executorService = Executors.newFixedThreadPool(numCores + numGpus)
+        modelManager.setAvailableMemory((cpuMemoryBudgetInBytes*0.80).toLong)
+
+        if (numCores > 0)
+            executorTypes :+= "CPU"
+        if (numGpus > 0)
+            executorTypes :+= "GPU"
+
+        LOG.debug("STARTING SCHEDULER WITH: " + numCores + " CPUs and " + numGpus + " GPUs")
+        for (i <- 0 until numCores) {
+            val exec = new JmlcExecutor(this, "CPU", "CPU" + i, null)
+            executorQueues.put(exec, new BatchQueue("CPU", "CPU" + i))
+            executorService.submit(exec)
+        }
+        for (i <- 0 until numGpus) {
+            val exec = new JmlcExecutor(this, "GPU","GPU" + i, gCtxs.get(i))
+            executorQueues.put(exec, new BatchQueue("GPU", "GPU" + i))
+            executorService.submit(exec)
+        }
+    }
+
+    def shutdown(): Unit = {
+        executorService.shutdown()
+    }
+
+    def schedule(executor: JmlcExecutor): Array[SchedulingRequest]
+
+    /**
+      * Registers a model with this scheduler. This should be called before enqueueing requests
+      * @param model Model object to be registered
+      */
+    def addModel(model: Model): Unit = {
+        modelQueues.putIfAbsent(model.name, new LinkedBlockingDeque[SchedulingRequest]())
+        latencyObjectives.putIfAbsent(model.name, model.latencyObjective)
+        modelManager.put(model)
+    }
+
+    /**
+      * Sets a flag indicating if detailed statistics should be gathered which profile the time spent
+      * in various stages of the execution pipeline
+      * @param flag Boolean flag indicating whether statistics should be gathered
+      */
+    def setStatistics(flag: Boolean): Unit = { _statistics = flag }
+
+    def timeout: Duration = 300.seconds
+
+    /**
+      * Method used to update scheduler state after a batch has been executed. If necessary,
+      * objects implementing the Scheduler trait should override this method and implement any
+      * logic needed to post-process execution of a batch
+      *
+      * @param model String indicating the name of the model which was just executed
+      * @param latency A measure of latency for this batch
+      * @param batchSize The number of examples in the batch
+      * @param execType The device type on which the batch was executed
+      * @param execTime The time spent executing the batch
+      */
+    def onCompleteCallback(model: String, latency: Double, batchSize: Int, execType: String, execTime: Long) : Unit
+
+    val requestQueue = new LinkedBlockingDeque[SchedulingRequest]()
+    val globalSchedulingQueues = new ConcurrentHashMap[String, BatchQueue]()
+    var modelQueues = new ConcurrentHashMap[String, BlockingQueue[SchedulingRequest]]()
+    var executorQueues = new ConcurrentHashMap[JmlcExecutor, BatchQueue]()
+    val dummyResponse = PredictionResponse(null, -1, null)
+    val latencyObjectives = new ConcurrentHashMap[String, Duration]()
+    var counter = 0
+
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine which
+      * models to execute next
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which prediction is desired
+      * @return A Future which will eventually contain the PredictionResponse for this request
+      */
+    private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse]
+}
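The latch-based handshake between enqueue() and the executor threads is roughly the
following (a sketch, not the actual JmlcExecutor code; resultBlock is a placeholder for
the computed output):
```
// Executor side: drain work from the scheduler, attach the result, and release
// each request's latch so the Future created in enqueue() can complete.
val requests = scheduler.schedule(executor)
for (req <- requests) {
    req.response = PredictionResponse(resultBlock, requests.length, req.statistics)
    req.latch.countDown() // wakes the enqueue() call awaiting this request
}
```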
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala b/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala
new file mode 100644
index 0000000..6fa6007
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+object SchedulerFactory {
+    def getScheduler(schedulerType: String) : Scheduler = {
+        schedulerType match {
+            case "non-batching"   => NonBatchingScheduler
+            case "basic-batching" => BasicBatchingScheduler
+            case "locality-aware" => LocalityAwareScheduler
+            // fail fast with a clear message instead of an opaque MatchError
+            case _ => throw new IllegalArgumentException("Unknown scheduler type: " + schedulerType)
+        }
+    }
+}
\ No newline at end of file