This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 863c9d5  [SYSTEMML-2525] Initial implementation of RESTful model 
serving system
863c9d5 is described below

commit 863c9d5cb1752b0e50140f5c6673968b57c2f9d0
Author: Anthony Thomas <ahtho...@eng.ucsd.edu>
AuthorDate: Fri Mar 29 10:27:54 2019 -0700

    [SYSTEMML-2525] Initial implementation of RESTful model serving system
    
    - The current implementation extends JMLC's readMatrix and GPUContext APIs.
    - The serving system is implemented in Scala using Akka and is available in 
the org.apache.sysml.api.ml.serving package.
    - Minor cleanup and refactoring, which are required before the system is ready 
for general use, will be done in subsequent commits.
    - It is still unclear whether the CUDA and serving code should be included 
in future standalone releases. If so, deployment will be greatly simplified; 
otherwise, users will have to build the standalone jar themselves before deployment.
    - The serving system can be started by:
    ```
    mvn -Djcuda.scope=compile -Dserving.scope=compile package -P standalone-jar
    java -jar systemml-*-standalone.jar 
org.apache.sysml.api.ml.serving.PredictionService -port 8099 -scheduler 
scheduler -admin_password admin
    ```
    - Models can be registered using http://localhost:8099/register-model, and 
predictions can be requested using the http://localhost:8099/predict service.
    
    Closes #860.
---
 .travis.yml                                        |   5 +-
 pom.xml                                            |  31 ++
 .../java/org/apache/sysml/api/jmlc/Connection.java |  51 +++
 .../org/apache/sysml/api/jmlc/PreparedScript.java  |  18 +
 .../org/apache/sysml/parser/DataExpression.java    |   1 +
 .../runtime/controlprogram/LocalVariableMap.java   |   4 +
 .../org/apache/sysml/utils/PersistentLRUCache.java |  97 ++--
 .../api/ml/serving/BasicBatchingScheduler.scala    |  93 ++++
 .../sysml/api/ml/serving/BatchingScheduler.scala   |  99 +++++
 .../sysml/api/ml/serving/BatchingUtils.scala       |  57 +++
 .../org/apache/sysml/api/ml/serving/Executor.scala | 155 +++++++
 .../api/ml/serving/LocalityAwareScheduler.scala    | 218 +++++++++
 .../apache/sysml/api/ml/serving/ModelManager.scala | 176 ++++++++
 .../api/ml/serving/NonBatchingScheduler.scala      |  69 +++
 .../sysml/api/ml/serving/PredictionService.scala   | 490 +++++++++++++++++++++
 .../apache/sysml/api/ml/serving/RLSEstimator.scala |  91 ++++
 .../apache/sysml/api/ml/serving/Scheduler.scala    | 133 ++++++
 .../sysml/api/ml/serving/SchedulerFactory.scala    |  29 ++
 18 files changed, 1756 insertions(+), 61 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index a0c308b..3ce9d06 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -46,7 +46,8 @@ before_script:
 
 script:
 #  - mvn clean verify jacoco:report coveralls:report
-  - mvn clean verify
+# The -q parameter tells mvn not to display anything other than ERROR level 
log messages. This is required because travis kills the job after the log 
length exceeds its maximum log length (usually 4 MB).
+  - mvn -q clean verify
 
 after_success:
-#  -  mvn test jacoco:report coveralls:report
\ No newline at end of file
+#  -  mvn test jacoco:report coveralls:report
diff --git a/pom.xml b/pom.xml
index ad74276..4b5dd29 100644
--- a/pom.xml
+++ b/pom.xml
@@ -72,6 +72,7 @@
                <maven.build.timestamp.format>yyyy-MM-dd HH:mm:ss 
z</maven.build.timestamp.format>
                <enableGPU>false</enableGPU>
                <jcuda.scope>provided</jcuda.scope>
+               <serving.scope>provided</serving.scope>
                <jcuda.version>0.9.0d</jcuda.version>
                <!-- OS-specific JVM arguments for running integration tests -->
                <integrationTestExtraJVMArgs />
@@ -1259,6 +1260,36 @@
                        <version>3.2.0</version>
                </dependency>
                <dependency>
+                   <groupId>com.typesafe.akka</groupId>
+                   <artifactId>akka-http_2.11</artifactId>
+                   <version>10.1.3</version>
+                   <scope>${serving.scope}</scope>
+               </dependency>
+               <dependency>
+                   <groupId>com.typesafe.akka</groupId>
+                   <artifactId>akka-actor_2.11</artifactId>
+                   <version>2.5.14</version>
+                       <scope>${serving.scope}</scope>
+               </dependency>
+               <dependency>
+                   <groupId>com.typesafe.akka</groupId>
+                   <artifactId>akka-stream_2.11</artifactId>
+                   <version>2.5.14</version>
+                       <scope>${serving.scope}</scope>
+               </dependency>
+               <dependency>
+                   <groupId>com.typesafe</groupId>
+                   <artifactId>config</artifactId>
+                   <version>1.2.0</version>
+                       <scope>${serving.scope}</scope>
+               </dependency>
+               <dependency>
+                   <groupId>com.typesafe.akka</groupId>
+                   
<artifactId>akka-http-spray-json-experimental_2.11</artifactId>
+                   <version>2.4.11.2</version>
+                       <scope>${serving.scope}</scope>
+               </dependency>
+               <dependency>
                        <groupId>org.jcuda</groupId>
                        <artifactId>jcuda</artifactId>
                        <version>${jcuda.version}</version>
diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java 
b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
index 53b7d04..29df4c0 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
@@ -370,6 +370,57 @@ public class Connection implements Closeable
        // Read matrices
        ////////////////////////////////////////////
        
+       /**
+        * Reads an input matrix from HDFS into a MatrixBlock. The number of rows and
+        * columns, the block sizes, the number of non-zeros and the format are taken
+        * from the metadata file located via DataExpression.getMTDFileName(fname).
+        * Block sizes and the number of non-zeros default to -1 (unknown) when they
+        * are absent from the metadata.
+        *
+        * @param fname the filename of the input matrix
+        * @return the matrix as a MatrixBlock
+        * @throws IOException if the metadata cannot be read or parsed, or if reading the matrix fails
+        */
+       public MatrixBlock readMatrix(String fname) throws IOException {
+               try {
+                       String fnamemtd = DataExpression.getMTDFileName(fname);
+                       JSONObject jmtd = new DataExpression().readMetadataFile(fnamemtd, false);
+
+                       //parse json meta data
+                       long rows = jmtd.getLong(DataExpression.READROWPARAM);
+                       long cols = jmtd.getLong(DataExpression.READCOLPARAM);
+                       int brlen = jmtd.containsKey(DataExpression.ROWBLOCKCOUNTPARAM)?
+                                       jmtd.getInt(DataExpression.ROWBLOCKCOUNTPARAM) : -1;
+                       int bclen = jmtd.containsKey(DataExpression.COLUMNBLOCKCOUNTPARAM)?
+                                       jmtd.getInt(DataExpression.COLUMNBLOCKCOUNTPARAM) : -1;
+                       long nnz = jmtd.containsKey(DataExpression.READNNZPARAM)?
+                                       jmtd.getLong(DataExpression.READNNZPARAM) : -1;
+                       String format = jmtd.getString(DataExpression.FORMAT_TYPE);
+                       InputInfo iinfo = InputInfo.stringExternalToInputInfo(format);
+                       return readMatrix(fname, iinfo, rows, cols, brlen, bclen, nnz);
+               }
+               catch (IOException ex) {
+                       // The overload below already reports failures as IOException;
+                       // rethrow as-is instead of wrapping it a second time.
+                       throw ex;
+               }
+               catch (Exception ex) {
+                       throw new IOException(ex);
+               }
+       }
+       
+       /**
+        * Reads an input matrix in arbitrary format from HDFS into a MatrixBlock.
+        * NOTE: this call currently only supports default configurations for CSV.
+        *
+        * @param fname the filename of the input matrix
+        * @param iinfo InputInfo object describing the input format
+        * @param rows number of rows in the matrix
+        * @param cols number of columns in the matrix
+        * @param brlen number of rows per block
+        * @param bclen number of columns per block
+        * @param nnz number of non-zero values, -1 indicates unknown
+        * @return the matrix as a MatrixBlock
+        * @throws IOException if creating the reader or reading the matrix fails
+        */
+       public MatrixBlock readMatrix(String fname, InputInfo iinfo, long rows, long cols, int brlen, int bclen, long nnz)
+                       throws IOException
+       {
+               setLocalConfigs();
+
+               try {
+                       MatrixReader reader = MatrixReaderFactory.createMatrixReader(iinfo);
+                       return reader.readMatrixFromHDFS(fname, rows, cols, brlen, bclen, nnz);
+
+               }
+               catch(Exception ex) {
+                       throw new IOException(ex);
+               }
+       }
+       
        /**
         * Reads an input matrix in arbitrary format from HDFS into a dense 
double array.
         * NOTE: this call currently only supports default configurations for 
CSV.
diff --git a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java 
b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
index 701af30..5926bcc 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
@@ -75,6 +75,11 @@ public class PreparedScript implements ConfigurableAPI
        private final HashSet<String> _outVarnames;
        private final HashMap<String,Data> _inVarReuse;
        
+       /** Name of this prepared script; the empty string until setName is called. */
+       private String name = "";
+       /** @param name the name to assign to this prepared script */
+       public void setName(String name) { this.name = name; }
+       
        //internal state (reused)
        private final Program _prog;
        private final LocalVariableMap _vars;
@@ -131,6 +136,19 @@ public class PreparedScript implements ConfigurableAPI
                _cconf = cconf;
        }
        
+       /** Removes every entry held in the pinned-input map (_inVarReuse). */
+       public void clearPinnedData() {
+               _inVarReuse.clear();
+       }
+       
+       /** @return true if the pinned-input map (_inVarReuse) holds at least one entry */
+       public boolean hasPinnedData() {
+               return !_inVarReuse.isEmpty();
+       }
+       
+       /**
+        * Replaces the GPU context stored at index 0 of _gpuCtx.
+        * NOTE(review): assumes _gpuCtx already contains an element at index 0 — confirm.
+        * @param gCtx the GPU context to use
+        */
+       public void setGpuContext(GPUContext gCtx) {
+               _gpuCtx.set(0, gCtx);
+       }
+       
+       
        /**
         * Sets a boolean flag indicating if runtime statistics should be 
gathered
         * Same behavior as in "MLContext.setStatistics()"
diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java 
b/src/main/java/org/apache/sysml/parser/DataExpression.java
index 44f368e..7b64922 100644
--- a/src/main/java/org/apache/sysml/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysml/parser/DataExpression.java
@@ -44,6 +44,7 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.utils.JSONHelper;
 import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONObject;
+import org.apache.sysml.parser.Expression.DataOp;
 
 
 public class DataExpression extends DataIdentifier 
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index cf9e79a..bb06849 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -131,6 +131,10 @@ public class LocalVariableMap implements Cloneable
                        put(kv.getKey(), kv.getValue());
                }
        }
+       
+       public void putAll(LocalVariableMap vars) { 
+               putAll(vars.localMap); 
+       }
 
        public Data remove( String name ) {
                Data ret = localMap.remove( name );
diff --git a/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java 
b/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
index d9d9337..24e685b 100644
--- a/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
+++ b/src/main/java/org/apache/sysml/utils/PersistentLRUCache.java
@@ -86,7 +86,7 @@ public class PersistentLRUCache extends LinkedHashMap<String, 
ValueWrapper> {
        private String _prefixFilePath;
        final AtomicLong _currentNumBytes = new AtomicLong();
        private final long _maxNumBytes;
-       private static final Random _rand = new Random();
+       Random _rand = new Random();
        boolean isInReadOnlyMode;
        HashSet<String> persistedKeys = new HashSet<>();
        
@@ -101,9 +101,6 @@ public class PersistentLRUCache extends 
LinkedHashMap<String, ValueWrapper> {
                for(long i = 0; i < numIter; ++i) {
                        LOG.debug("Putting a double array of size 50MB.");
                        cache.put("file_" + i, new double[numDoubleIn50MB]);
-                       try {
-                               Thread.sleep(100);
-                       } catch (InterruptedException e) {}
                }
                cache.clear();
        }
@@ -130,13 +127,13 @@ public class PersistentLRUCache extends 
LinkedHashMap<String, ValueWrapper> {
                _prefixFilePath = tmp.getAbsolutePath();
        }
        public ValueWrapper put(String key, double[] value) throws 
FileNotFoundException, IOException {
-               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this), isInReadOnlyMode), value.length*Double.BYTES);
+               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this)), value.length*Double.BYTES);
        }
        public ValueWrapper put(String key, float[] value) throws 
FileNotFoundException, IOException {
-               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this), isInReadOnlyMode), value.length*Float.BYTES);
+               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this)), value.length*Float.BYTES);
        }
        public ValueWrapper put(String key, MatrixBlock value) throws 
FileNotFoundException, IOException {
-               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this), isInReadOnlyMode), value.getInMemorySize());
+               return putImplm(key, new ValueWrapper(new DataWrapper(key, 
value, this)), value.getInMemorySize());
        }
        
        private ValueWrapper putImplm(String key, ValueWrapper value, long 
sizeInBytes) throws FileNotFoundException, IOException {
@@ -209,7 +206,7 @@ public class PersistentLRUCache extends 
LinkedHashMap<String, ValueWrapper> {
     }
        
        float [] tmp = new float[0];
-       static String dummyKey = "RAND_KEY_" + Math.abs(_rand.nextLong()) + "_" 
+ Math.abs(_rand.nextLong());
+       String dummyKey = "RAND_KEY_" + Math.abs(_rand.nextLong()) + "_" + 
Math.abs(_rand.nextLong());
        void ensureCapacity(long newNumBytes) throws FileNotFoundException, 
IOException {
                if(newNumBytes > _maxNumBytes) {
                        throw new DMLRuntimeException("Exceeds maximum 
capacity. Cannot put a value of size " + newNumBytes + 
@@ -220,7 +217,7 @@ public class PersistentLRUCache extends 
LinkedHashMap<String, ValueWrapper> {
                        synchronized(this) {
                                if(LOG.isDebugEnabled())
                                        LOG.debug("The required capacity (" + 
newCapacity + ") is greater than max capacity:" + _maxNumBytes);
-                               ValueWrapper dummyValue = new ValueWrapper(new 
DataWrapper(dummyKey, tmp, this), isInReadOnlyMode);
+                               ValueWrapper dummyValue = new ValueWrapper(new 
DataWrapper(dummyKey, tmp, this));
                                int maxIter = size();
                                while(_currentNumBytes.get() > _maxNumBytes && 
maxIter > 0) {
                                        super.put(dummyKey, dummyValue); // 
This will invoke removeEldestEntry, which will set _eldest
@@ -351,13 +348,17 @@ class DataWrapper {
                _mo = value;
                _cache = cache;
        }
+       @Override
+       protected void finalize() throws Throwable {
+               super.finalize();
+               write(true);
+       }
        
-       public synchronized void write(boolean forceAggresiveWrites) throws 
FileNotFoundException, IOException {
-               if(_key.equals(PersistentLRUCache.dummyKey))
+       public synchronized void write(boolean isBeingGarbageCollected) throws 
FileNotFoundException, IOException {
+               if(_key.equals(_cache.dummyKey))
                        return;
-               
-               // Prepare for writing
                _cache.makeRecent(_key); // Make it recent.
+               
                if(_dArr != null || _fArr != null || _mb != null || _mo != 
null) {
                        _cache._currentNumBytes.addAndGet(-getSize());
                }
@@ -365,16 +366,14 @@ class DataWrapper {
                if(!_cache.isInReadOnlyMode) {
                        String debugSuffix = null;
                        if(PersistentLRUCache.LOG.isDebugEnabled()) {
-                               if(forceAggresiveWrites)
-                                       debugSuffix = " (aggressively 
written).";
+                               if(isBeingGarbageCollected)
+                                       debugSuffix = " (is being garbage 
collected).";
                                else
                                        debugSuffix = " (capacity exceeded).";
                        }
                        
                        if(_dArr != null) {
-                               File file = new File(_cache.getFilePath(_key));
-                               file.deleteOnExit();
-                               try (ObjectOutputStream os = new 
ObjectOutputStream(new FileOutputStream(file))) {
+                               try (ObjectOutputStream os = new 
ObjectOutputStream(new FileOutputStream(_cache.getFilePath(_key)))) {
                                        os.writeInt(_dArr.length);
                                        for(int i = 0; i < _dArr.length; i++) {
                                                os.writeDouble(_dArr[i]);
@@ -385,9 +384,7 @@ class DataWrapper {
                                        PersistentLRUCache.LOG.debug("Writing 
value (double[] of size " + getSize() + " bytes) for the key " + _key + " to 
disk" + debugSuffix);
                        }
                        else if(_fArr != null) {
-                               File file = new File(_cache.getFilePath(_key));
-                               file.deleteOnExit();
-                               try (ObjectOutputStream os = new 
ObjectOutputStream(new FileOutputStream(file))) {
+                               try (ObjectOutputStream os = new 
ObjectOutputStream(new FileOutputStream(_cache.getFilePath(_key)))) {
                                        os.writeInt(_fArr.length);
                                        for(int i = 0; i < _fArr.length; i++) {
                                                os.writeFloat(_fArr[i]);
@@ -398,13 +395,12 @@ class DataWrapper {
                                        PersistentLRUCache.LOG.debug("Writing 
value (float[] of size " + getSize() + " bytes) for the key " + _key + " to 
disk" + debugSuffix);
                        }
                        else if(_mb != null) {
-                               File file = new File(_cache.getFilePath(_key));
-                               file.deleteOnExit();
-                               try(FastBufferedDataOutputStream os = new 
FastBufferedDataOutputStream(new ObjectOutputStream(new 
FileOutputStream(file)))) {
+                               try(FastBufferedDataOutputStream os = new 
FastBufferedDataOutputStream(new ObjectOutputStream(new 
FileOutputStream(_cache.getFilePath(_key))))) {
                                        os.writeLong(_mb.getInMemorySize());
                                        _mb.write(os);
                                }
                                _cache.persistedKeys.add(_key);
+                               System.err.println("Writing value (MatrixBlock 
of size " + getSize() + " bytes) for the key " + _key + " to disk" + 
debugSuffix);
                                if(PersistentLRUCache.LOG.isDebugEnabled())
                                        PersistentLRUCache.LOG.debug("Writing 
value (MatrixBlock of size " + getSize() + " bytes) for the key " + _key + " to 
disk" + debugSuffix);
                        }
@@ -513,63 +509,46 @@ class DataWrapper {
 // Internal helper class
 class ValueWrapper {
        final Object _lock;
-       final boolean _isInReadOnlyMode;
-       private SoftReference<DataWrapper> _softRef;
+       private SoftReference<DataWrapper> _ref;
        long _rlen;
        long _clen;
        long _nnz;
        
-       ValueWrapper(DataWrapper data, boolean isInReadOnlyMode) {
+       ValueWrapper(DataWrapper _data) {
                _lock = new Object();
-               _isInReadOnlyMode = isInReadOnlyMode;
-               boolean isDummyValue = (data._key == 
PersistentLRUCache.dummyKey);
-               if(!_isInReadOnlyMode && !isDummyValue) {
-                       // Aggressive write to disk when the cache is used in 
the write-mode.
-                       // This avoids the need to depend on finalize to 
perform writing.
-                       Thread t = new Thread() {
-                           public void run() {
-                               try {
-                                       data.write(true);
-                                       } catch (IOException e) {
-                                               throw new 
DMLRuntimeException("Error occured while aggressively writing the value to 
disk.", e);
-                                       }
-                           }
-                       };
-                       t.start();
-               }
-               _softRef = new SoftReference<>(data);
-               if(data._mb != null) {
-                       _rlen = data._mb.getNumRows();
-                       _clen = data._mb.getNumColumns();
-                       _nnz = data._mb.getNonZeros();
+               _ref = new SoftReference<>(_data);
+               if(_data._mb != null) {
+                       _rlen = _data._mb.getNumRows();
+                       _clen = _data._mb.getNumColumns();
+                       _nnz = _data._mb.getNonZeros();
                }
        }
-       void update(DataWrapper data) {
-               _softRef = new SoftReference<>(data);
-               if(data._mb != null) {
-                       _rlen = data._mb.getNumRows();
-                       _clen = data._mb.getNumColumns();
-                       _nnz = data._mb.getNonZeros();
+       void update(DataWrapper _data) {
+               _ref = new SoftReference<>(_data);
+               if(_data._mb != null) {
+                       _rlen = _data._mb.getNumRows();
+                       _clen = _data._mb.getNumColumns();
+                       _nnz = _data._mb.getNonZeros();
                }
        }
        boolean isAvailable() {
-               DataWrapper data = _softRef.get();
+               DataWrapper data = _ref.get();
                return data != null && data.isAvailable();
        }
        DataWrapper get() {
-               return _softRef.get();
+               return _ref.get();
        }
        long getSize() {
-               DataWrapper data = _softRef.get();
+               DataWrapper data = _ref.get();
                if(data != null) 
                        return data.getSize();
                else
                        return 0;
        }
        void remove() {
-               DataWrapper data = _softRef.get();
+               DataWrapper data = _ref.get();
                if(data != null) {
                        data.remove();
                }
        }
-}
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala
new file mode 100644
index 0000000..6f19cce
--- /dev/null
+++ 
b/src/main/scala/org/apache/sysml/api/ml/serving/BasicBatchingScheduler.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+import scala.math.min
+
+object BasicBatchingScheduler extends BatchingScheduler {
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Basic Batching Scheduler with: ${numCores} CPUs and ${gpus} GPUs")
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+    }
+
+    /**
+      * Returns the requests to execute next on the given executor. All returned requests belong to the
+      * same model; if there is more than one, they will be batched by the executor. Returns an empty
+      * array when no model is currently schedulable.
+      * @param executor an Executor instance
+      * @return the model requests to process
+      */
+    override def schedule(executor: JmlcExecutor) : Array[SchedulingRequest] = {
+        var ret = Array[SchedulingRequest]()
+        val execType = executor.getExecType
+        dummyResponse.synchronized {
+            val schedulableModels = getSchedulableModels(execType)
+            if (schedulableModels.nonEmpty) {
+                val (nextModel, nextBatchSize) = getNextModelAndBatchSize(schedulableModels, execType)
+                val queue = modelQueues.get(nextModel)
+                for (_ <- 0 until nextBatchSize) {
+                    val next = queue.poll()
+                    assert(next != null, "Something is wrong. Next model should not be null")
+                    ret :+= next
+                }
+            }
+        }
+        ret
+    }
+
+    /**
+      * Picks the next model to schedule and how many of its queued requests to batch. The chosen model
+      * minimizes (optimal batch size * expected per-request execution time).
+      * @param models the schedulable models; must be non-empty
+      * @param execType the execution type of the executor asking for work
+      * @return the chosen model name and the batch size, capped by the length of that model's queue
+      */
+    def getNextModelAndBatchSize(models : Iterable[String], execType: String) : (String, Int) = {
+        val nextModel = models.map(m =>
+            (getOptimalBatchSize(m, execType)*getExpectedExecutionTime(m), m)).minBy(x => x._1)._2
+
+        val nextBatchSize = min(modelQueues.get(nextModel).size(),
+            getOptimalBatchSize(nextModel, execType))
+        (nextModel, nextBatchSize)
+    }
+
+    /**
+      * Enqueues a request for processing. The scheduler reads from these queues to determine which
+      * models to execute next. The returned Future waits until the request's latch is released or the
+      * scheduler timeout elapses.
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model for which the prediction is requested
+      * @return a Future with the request's response, or dummyResponse if the latch was not released
+      *         before the timeout
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+        val queue = modelQueues.get(model.name)
+        // statistics is null when statistics collection is disabled
+        if (statistics != null) statistics.queueSize = queue.size
+        queue.add(schedulingRequest)
+        counter += 1
+        // CountDownLatch.await(timeout, unit) signals a timeout by returning false (it does not throw
+        // TimeoutException), so the result must be checked to fall back to dummyResponse.
+        if (schedulingRequest.latch.await(timeout.length, timeout.unit)) schedulingRequest.response
+        else dummyResponse
+    }
+
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala
new file mode 100644
index 0000000..e62f4ef
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingScheduler.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.LongAdder
+
+import scala.math.{floor, max}
+
+trait BatchingScheduler extends Scheduler {
+
+    /** Per execution type, the current optimal batch size of each model. */
+    val modelBatchSizes = new ConcurrentHashMap[String, ConcurrentHashMap[String,Int]]()
+    /** Per model: (sum of per-request execution times, number of completed batches), used for averaging. */
+    val expectedExecutionTimes = new ConcurrentHashMap[String, (LongAdder, LongAdder)]()
+
+    /**
+      * Returns the current optimal batch size of a model for the given execution type, initializing it
+      * to 2 the first time the (execType, model) pair is seen.
+      */
+    def getOptimalBatchSize(model : String, execType: String) : Int = {
+        modelBatchSizes.putIfAbsent(execType, new ConcurrentHashMap[String,Int]())
+        modelBatchSizes.get(execType).putIfAbsent(model, 2)
+        modelBatchSizes.get(execType).get(model)
+    }
+
+    /**
+      * Adapts a model's optimal batch size after a batched execution and records its execution time.
+      * The size grows by one when the observed latency is below the model's latency objective; otherwise
+      * it shrinks (to 90% when above 10, by one otherwise), never going below 1. Executions of a single
+      * request are ignored.
+      * @param model the model that was executed
+      * @param latency the observed latency, compared against the model's latency objective in nanoseconds
+      * @param batchSize the number of requests in the executed batch
+      * @param execType the execution type the batch ran on
+      * @param execTime the execution time of the whole batch
+      */
+    override def onCompleteCallback(model: String,
+                                    latency: Double,
+                                    batchSize: Int,
+                                    execType: String,
+                                    execTime: Long): Unit = {
+        if (batchSize > 1) {
+            val latencyObjective = latencyObjectives.get(model)
+            // getOptimalBatchSize also creates the entries if this callback arrives before any lookup
+            val prevSize = getOptimalBatchSize(model, execType)
+            // Clamp to 1: a size of 0 makes every non-empty queue "batchable" in getSchedulableModels
+            // while yielding empty batches, and the batchSize > 1 guard above would never let it recover.
+            val decreaseSize = if (prevSize > 10) max(floor(prevSize * 0.90).toInt, 1) else max(prevSize - 1, 1)
+            modelBatchSizes.get(execType).put(model,
+                if (latency < latencyObjective.toNanos) prevSize + 1 else decreaseSize)
+
+            // update expected execution times. For now we just assume this is a simple average
+            expectedExecutionTimes.putIfAbsent(model, (new LongAdder(), new LongAdder()))
+            val execTimeData = expectedExecutionTimes.get(model)
+            execTimeData._1.add(execTime / batchSize)
+            execTimeData._2.increment()
+        }
+    }
+
+    /**
+      * Returns the mean per-request execution time recorded for a model (the average of execTime / batchSize
+      * over completed batches), or 0 if nothing has been recorded yet.
+      */
+    def getExpectedExecutionTime(model: String) : Long = {
+        expectedExecutionTimes.putIfAbsent(model, (new LongAdder(), new LongAdder()))
+        val (totalExecTime, totalNumBatches) = expectedExecutionTimes.get(model)
+        // read the counter once so the zero check and the division use the same value
+        val numBatches = totalNumBatches.longValue()
+        if (numBatches > 0) totalExecTime.longValue() / numBatches else 0
+    }
+
+    /**
+      * Gets the models that are eligible to be run. A model is eligible if it has at least as many
+      * requests enqueued as its optimal batch size. If any model's oldest queued request is close to
+      * violating its latency objective (see checkShortFuse), only such models are returned.
+      * @param execType the execution type used to look up optimal batch sizes
+      * @return the models which may be scheduled
+      */
+    def getSchedulableModels(execType: String) : Set[String] = {
+        var batchableModels = Set[String]()
+        var shortFuse = Set[String]()
+        val keyIterator = modelQueues.keys()
+        while (keyIterator.hasMoreElements) {
+            val name = keyIterator.nextElement()
+            val qsize = modelQueues.get(name).size()
+            if (qsize > 0) {
+                val nextRequest = modelQueues.get(name).peek()
+                assert(nextRequest != null, "Something is wrong. Next request should not be null")
+
+                if (checkShortFuse(nextRequest, qsize)) {
+                    LOG.info("Model: " + name + " is near violating threshold. Scheduling immediately.")
+                    shortFuse += name
+                }
+
+                if (qsize >= getOptimalBatchSize(name, execType)) {
+                    batchableModels += name
+                }
+            }
+        }
+
+        if (shortFuse.nonEmpty) shortFuse else batchableModels
+    }
+
+    /**
+      * Returns true if the time the given request has already waited, plus 1.1x the expected execution
+      * time of numRequests requests of its model, would exceed the model's latency objective.
+      */
+    def checkShortFuse(request: SchedulingRequest, numRequests: Int) : Boolean = {
+        val elapsed = System.nanoTime() - request.receivedTime
+        (elapsed + 1.1*numRequests*getExpectedExecutionTime(request.model.name)) > request.model.latencyObjective.toNanos
+    }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala
new file mode 100644
index 0000000..ca28e7f
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/BatchingUtils.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+
+/** Helpers that combine several requests into one input matrix and split the result again. */
+object BatchingUtils {
+    /**
+      * Stacks the input matrices of `requests` row-wise into a single matrix so that they can be
+      * scored by one script execution.
+      *
+      * A single request is returned as-is (no copy). Otherwise all rows of every request are
+      * copied, in order, into a freshly allocated dense block.
+      *
+      * NOTE(review): assumes every input has the same number of columns and is stored densely
+      * (getDenseBlockValues returning its row-major values) — sparse/empty inputs would need
+      * converting first.
+      *
+      * @param requests the requests to batch, in the order their results will be unbatched
+      * @return a matrix holding the rows of all requests
+      */
+    def batchRequests(requests: Array[SchedulingRequest]) : MatrixBlock = {
+        if (requests.length == 1) {
+            requests(0).request.data
+        } else {
+            val ncol = requests(0).request.data.getNumColumns
+            val nrow = requests.map(_.request.data.getNumRows).sum
+            val res = new MatrixBlock(nrow, ncol, -1).allocateDenseBlock()
+            val doubles = res.getDenseBlockValues
+            var start = 0
+            for (req <- requests) {
+                // copy every row of this request, not just the first one
+                val len = req.request.data.getNumRows * ncol
+                System.arraycopy(req.request.data.getDenseBlockValues, 0, doubles, start, len)
+                start += len
+            }
+            res.setNonZeros(-1)
+            res
+        }
+    }
+
+    /**
+      * Splits `batchedResults` back into one response per request.
+      *
+      * Request i receives the `requestSize` consecutive result rows that follow the rows of
+      * requests 0..i-1, i.e. the same order in which batchRequests stacked the inputs.
+      *
+      * @param requests       the requests that were batched, in batching order
+      * @param batchedResults the script output for the whole batch
+      * @return one PredictionResponse per request, in the same order
+      */
+    def unbatchRequests(requests: Array[SchedulingRequest],
+                        batchedResults: MatrixBlock) : Array[PredictionResponse] = {
+        // first result row belonging to the current request; must advance per request
+        var start = 0
+        requests.map { req =>
+            val unbatchStart = System.nanoTime()
+            val end = start + req.request.requestSize - 1
+            val resp = PredictionResponse(batchedResults.slice(start, end),
+                                          batchedResults.getNumRows, req.statistics)
+            val unbatchingTime = System.nanoTime() - unbatchStart
+            start = end + 1
+            if (req.statistics != null)
+                req.statistics.unbatchingTime = unbatchingTime
+            resp
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala
new file mode 100644
index 0000000..c353e07
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/Executor.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml.serving
+import java.util.concurrent.PriorityBlockingQueue
+import java.util.concurrent.atomic.LongAdder
+
+import org.apache.commons.logging.{Log, LogFactory}
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext
+
+
+/**
+  * A batch of `size` queued requests for model `modelName`, as placed on a [[BatchQueue]].
+  *
+  * `expectedTime` is the batch's estimated execution time. Batches compare by `priority` alone,
+  * in ascending order, so a priority queue hands out the batch with the smallest priority first.
+  */
+case class Batch(size: Int, expectedTime: Long, priority: Double, modelName: String)
+    extends Comparable[Batch] {
+    override def compareTo(that: Batch): Int =
+        java.lang.Double.compare(priority, that.priority)
+}
+
+/**
+  * A priority queue of [[Batch]]es for executors of type `execType`.
+  *
+  * Besides the batches it tracks the sum of their expected execution times (used to compare the
+  * load of queues) and, per model, the request that headed the model's queue when a batch was
+  * last considered for this queue.
+  *
+  * NOTE(review): the expected execution time is only kept accurate when batches are removed via
+  * dequeue(); removing them with poll() directly bypasses that bookkeeping.
+  */
+class BatchQueue(execType: String, name: String) extends PriorityBlockingQueue[Batch] {
+    val LOG: Log = LogFactory.getLog(classOf[BatchQueue].getName)
+    private val expectedExecutionTime = new LongAdder()
+    private var prevFirstRequest = Map[String, SchedulingRequest]()
+
+    /** @return the name of this queue, used in log messages */
+    def getName : String = { name }
+
+    /** Records `request` as the last seen head request of the queue of model `name`. */
+    def updatePrevRequest(name: String, request: SchedulingRequest) : Unit = {
+        prevFirstRequest += (name -> request)
+    }
+
+    /** @return the last recorded head request for model `name`, or null if none was recorded */
+    def getPrevRequest(name: String) : SchedulingRequest = { prevFirstRequest.getOrElse(name, null) }
+
+    /** Adds `batch` and accounts for its expected execution time. */
+    def enqueue(batch: Batch) : Unit = {
+        LOG.debug("Enqueuing onto: " + getName)
+        synchronized {
+            this.add(batch)
+            expectedExecutionTime.add(batch.expectedTime)
+        }
+    }
+
+    /**
+      * Removes the batch with the smallest priority and subtracts its expected execution time.
+      * @return the removed batch, or a sentinel Batch(-1, -1, -1, "NO NAME") if the queue is empty
+      */
+    def dequeue() : Batch = {
+        synchronized {
+            // poll under the lock and check the result: an isEmpty check made beforehand can be
+            // invalidated by a concurrent poll, which would make nextBatch null
+            val nextBatch = this.poll()
+            if (nextBatch == null) {
+                Batch(-1, -1, -1, "NO NAME")
+            } else {
+                expectedExecutionTime.add(-1 * nextBatch.expectedTime)
+                nextBatch
+            }
+        }
+    }
+
+    /** @return the summed expected execution time of the batches enqueued via enqueue() */
+    def getExpectedExecutionTime : Long = { expectedExecutionTime.longValue() }
+
+    def getExecType : String = { execType }
+}
+
+/**
+  * A worker that repeatedly asks the scheduler for a batch of requests, executes the batch with
+  * the model's JMLC script and hands each waiting caller its response.
+  *
+  * @param scheduler the scheduler that supplies batches and exposes the model manager
+  * @param execType  the type of this executor (the model manager compares it against "GPU")
+  * @param name      the name of this executor, used in log messages and as the script name
+  * @param gCtx      a GPU context (NOTE(review): not referenced by this class)
+  */
+class JmlcExecutor(scheduler: Scheduler, execType: String, name: String, gCtx: GPUContext) extends Runnable {
+    @volatile protected var _shouldShutdown: Boolean = false
+    val LOG: Log = LogFactory.getLog(classOf[JmlcExecutor].getName)
+    // name of the model this executor ran most recently; "" before the first batch
+    var prevModel = ""
+
+    /** Asks the run loop to stop before fetching its next batch. */
+    def shutdown(): Unit = {
+        _shouldShutdown = true
+    }
+
+    def getExecType: String = { execType }
+
+    def getName: String = { name }
+
+    /** Fetches batches from the scheduler and executes them until shutdown() is called. */
+    def run(): Unit = {
+        // NOTE(review): initial delay, presumably to let the service finish starting up — confirm
+        Thread.sleep(1000)
+        while (!_shouldShutdown) {
+            val requests = scheduler.schedule(this)
+            if (requests.nonEmpty) {
+                val responses = execute(requests)
+                // Wake up every caller, including those left without a response because the batch
+                // failed; otherwise they stay blocked until their timeout expires.
+                for ((req, ix) <- requests.zipWithIndex) {
+                    if (ix < responses.length)
+                        req.response = responses(ix)
+                    req.latch.countDown()
+                }
+            }
+        }
+    }
+
+    /**
+      * Executes one batch of requests; the model of the first request is used for the whole batch.
+      *
+      * The inputs are stacked into one matrix, the model's script is acquired from the model
+      * manager and executed, and its output is split back into one response per request. The
+      * model and the memory reserved for the batch are released even if execution fails.
+      *
+      * @param requests the batch to execute
+      * @return one response per request in order, or an empty array if `requests` is empty or
+      *         execution failed before the results were unbatched
+      */
+    def execute(requests: Array[SchedulingRequest]): Array[PredictionResponse] = {
+        var responses = Array[PredictionResponse]()
+        if (requests.nonEmpty) {
+            val req = requests(0)
+            var modelAcquired = false
+            var resourcesReleased = false
+            try {
+                val start = System.nanoTime()
+                val batchedMatrixData = BatchingUtils.batchRequests(requests)
+                val batchingTime = System.nanoTime() - start
+                LOG.info("Executing: " + req.model.name + " with batch size: " + batchedMatrixData.getNumRows + " on " + name)
+                val modelAcquireStart = System.nanoTime()
+                val script = scheduler.modelManager.acquire(req.model.name, this)
+                modelAcquired = true
+                script.setName(this.getName)
+                val modelAcquireTime = System.nanoTime() - modelAcquireStart
+                script.setMatrix(req.model.inputVarName, batchedMatrixData, false)
+                val execStart = System.nanoTime()
+                val res = script.executeScript().getMatrixBlock(req.model.outputVarName)
+                val execTime = System.nanoTime() - execStart
+                responses = BatchingUtils.unbatchRequests(requests, res)
+
+                val modelReleaseStart = System.nanoTime()
+                // mark first so a failure inside release does not trigger a second release below
+                resourcesReleased = true
+                scheduler.modelManager.release(req.model.name)
+                scheduler.modelManager.releaseMemory(req.memUse)
+                val modelReleaseTime = System.nanoTime() - modelReleaseStart
+                scheduler.onCompleteCallback(req.model.name,
+                                             System.nanoTime() - req.receivedTime,
+                                             requests.length,
+                                             execType, System.nanoTime() - start)
+                if (req.statistics != null)
+                    setStatistics(requests, start, batchingTime, execTime, modelAcquireTime, modelReleaseTime)
+                if (prevModel.nonEmpty)
+                    scheduler.modelManager.unsetModelLocality(prevModel, this)
+                scheduler.modelManager.setModelLocality(req.model.name, this)
+                prevModel = req.model.name
+
+                LOG.info("Done executing request for: " + req.model.name + " on " + name)
+            } catch {
+                case e: Exception =>
+                    LOG.error("AN ERROR OCCURRED: " + e.getMessage, e)
+                    // give back what this batch still holds so later batches are not starved
+                    if (!resourcesReleased) {
+                        if (modelAcquired)
+                            scheduler.modelManager.release(req.model.name)
+                        scheduler.modelManager.releaseMemory(req.memUse)
+                    }
+            }
+        }
+        responses
+    }
+
+    /** Copies the timings of one executed batch into the statistics of each of its requests. */
+    def setStatistics(requests: Array[SchedulingRequest],
+                      processingStartTime: Long,
+                      batchingTime: Long,
+                      execTime: Long,
+                      modelAcquireTime: Long,
+                      modelReleaseTime: Long): Unit = {
+        for (req <- requests) {
+            req.statistics.batchingTime = batchingTime
+            req.statistics.execType = getExecType
+            req.statistics.batchSize = requests.length
+            req.statistics.queueWaitTime = processingStartTime - req.receivedTime
+            req.statistics.execTime = execTime
+            req.statistics.modelAcquireTime = modelAcquireTime
+            req.statistics.modelReleaseTime = modelReleaseTime
+        }
+    }
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala
new file mode 100644
index 0000000..61fc84f
--- /dev/null
+++ 
b/src/main/scala/org/apache/sysml/api/ml/serving/LocalityAwareScheduler.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
+
+import org.apache.commons.logging.{Log, LogFactory}
+
+import scala.concurrent.Future
+import scala.math.min
+
+/**
+  * Background loop for the [[LocalityAwareScheduler]] that turns queued requests into
+  * [[Batch]]es and places them on executor queues.
+  *
+  * For each schedulable model a batch is offered to the global disk queue of every executor type,
+  * to the global cache queues when the model is cached, and to the least-loaded queue of an
+  * executor the model is local to.
+  */
+object ExecutorQueueManager extends Runnable {
+    val LOG: Log = LogFactory.getLog(ExecutorQueueManager.getClass.getName)
+    // volatile so that shutdown() called from another thread is observed by the run() loop
+    @volatile var _shutDown = false
+    var _scheduler = LocalityAwareScheduler
+
+    /** Asks the run() loop to exit after its current pass. */
+    def shutdown(): Unit = { _shutDown = true }
+
+    /**
+      * Repeatedly, while holding the scheduler's lock, enqueues a batch for every model that is
+      * schedulable on any executor type. NOTE(review): the loop never sleeps, so it busy-waits
+      * while nothing is schedulable.
+      */
+    override def run() : Unit = {
+        while (!_shutDown) {
+            _scheduler.dummyResponse.synchronized {
+                // foldLeft rather than reduce so that an empty executorTypes does not throw
+                val schedulableModels = _scheduler.executorTypes.foldLeft(Set[String]())(
+                    (acc, execType) => acc union _scheduler.getSchedulableModels(execType))
+                schedulableModels.foreach(enqueueBatches)
+            }
+        }
+    }
+
+    /** Enqueues the next batch of model `m` onto every queue that is allowed to execute it. */
+    private def enqueueBatches(m: String) : Unit = {
+        // every request batch can go to up to three queues
+
+        // 1. Every batch goes to the global disk queue since the model might get evicted
+        val diskQueues = _scheduler.executorTypes.map(x => _scheduler.globalDiskQueues.get(x))
+
+        // 2. If the model is cached in memory, then also put it on the cache queue
+        val cacheQueues =
+            if (_scheduler.modelManager.isCached(m))
+                _scheduler.executorTypes.map(x => _scheduler.globalCacheQueues.get(x))
+            else
+                Array[BatchQueue]()
+
+        // 3. If the model is local to an executor, then put it on the lowest utilization queue
+        val localExecutionQueues = getLocalExecutionQueues(m)
+        val localQueue =
+            if (localExecutionQueues.nonEmpty)
+                Array[BatchQueue](localExecutionQueues.minBy(x => x.getExpectedExecutionTime))
+            else
+                Array[BatchQueue]()
+
+        val queues = diskQueues ++ cacheQueues ++ localQueue
+        val nextRequest = _scheduler.modelQueues.get(m).peek()
+        queues.foreach(queue => {
+            val qsize = _scheduler.modelQueues.get(m).size()
+            // only enqueue if this queue was not already offered a batch headed by this request
+            if (nextRequest ne queue.getPrevRequest(m)) {
+                val nextBatchSize = min(qsize, _scheduler.getOptimalBatchSize(m, queue.getExecType))
+                assert(nextBatchSize > 0, "An error occurred - batch size should not be zero")
+                LOG.debug("Enqueuing: " + nextBatchSize + " for: " + m + " onto: " + queue.getName)
+                // priority is the negated age of the head request, so older batches sort first
+                val nextBatch = Batch(
+                    nextBatchSize, nextBatchSize*_scheduler.getExpectedExecutionTime(m),
+                    nextRequest.receivedTime - System.nanoTime(), nextRequest.model.name)
+                queue.enqueue(nextBatch)
+                LOG.debug("Batch enqueued onto: " + queue.getName)
+            }
+            queue.updatePrevRequest(m, nextRequest)
+        })
+    }
+
+    /**
+      * Looks up the queues of the executors that `model` is currently local to.
+      * @param model the model name
+      * @return the executors' queues (empty if the model has no locality information)
+      */
+    def getLocalExecutionQueues(model: String) : Array[BatchQueue] = {
+        val execs = _scheduler.modelManager.getModelLocality(model)
+        if (execs == null) {
+            Array[BatchQueue]()
+        } else {
+            // the locality list is mutated under the model manager's lock; hold it while reading
+            _scheduler.modelManager.synchronized {
+                (0 until execs.size())
+                    .map(ix => _scheduler.executorQueues.get(execs.get(ix)))
+                    .filter(_ != null)
+                    .toArray
+            }
+        }
+    }
+}
+
+/** Identifies which kind of queue an executor took a [[Batch]] from. */
+object ExecMode extends Enumeration {
+    type MODE = Value
+    /** The executor's own (local) queue. */
+    val LOCAL: MODE = Value
+    /** The global cache queue of the executor's type. */
+    val GLOBAL_MEM: MODE = Value
+    /** The global disk queue of the executor's type. */
+    val GLOBAL_DISK: MODE = Value
+}
+
+/**
+  * A batching scheduler that lets each executor pick work from its own (local) queue, the
+  * global cache queue or the global disk queue of its executor type. Batches are produced by
+  * [[ExecutorQueueManager]], which this scheduler starts.
+  */
+object LocalityAwareScheduler extends BatchingScheduler {
+    var queueManager : Thread = _
+
+    val globalCacheQueues = new ConcurrentHashMap[String, BatchQueue]()
+    val globalDiskQueues = new ConcurrentHashMap[String, BatchQueue]()
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+
+        executorTypes.foreach ( x => {
+            globalCacheQueues.putIfAbsent(x, new BatchQueue(x, x + "-CACHE"))
+            globalDiskQueues.putIfAbsent(x, new BatchQueue(x, x + "-DISK"))
+        } )
+
+        queueManager = new Thread(ExecutorQueueManager)
+        queueManager.start()
+    }
+
+    override def addModel(model: Model): Unit = {
+        super.addModel(model)
+    }
+
+    /**
+      * Picks the next batch for `executor`: among its non-empty local, global-disk and
+      * global-cache queues, the one with the largest expected execution time is served.
+      *
+      * @return the requests to execute, or an empty array if there is no work, the chosen batch
+      *         was stale, or memory for it could not be reserved
+      */
+    override def schedule(executor: JmlcExecutor) : Array[SchedulingRequest] = {
+        var ret = Array[SchedulingRequest]()
+        val localQueue = executorQueues.get(executor)
+        val globalDiskQueue = globalDiskQueues.get(executor.getExecType)
+        val globalMemQueue = globalCacheQueues.get(executor.getExecType)
+        def hasWork: Boolean =
+            localQueue.size() > 0 || globalDiskQueue.size() > 0 || globalMemQueue.size() > 0
+
+        // cheap check without the lock first; repeated once the lock is held
+        if (hasWork) {
+            dummyResponse.synchronized {
+                if (hasWork) {
+                    LOG.debug("Begin scheduling for executor: " + executor.getName)
+                    val (queue, execMode) = Array[(BatchQueue, ExecMode.MODE)](
+                        (localQueue, ExecMode.LOCAL),
+                        (globalDiskQueue, ExecMode.GLOBAL_DISK),
+                        (globalMemQueue, ExecMode.GLOBAL_MEM)
+                    ).filter(x => x._1.size() > 0).maxBy(x => x._1.getExpectedExecutionTime)
+
+                    val batch = queue.peek()
+                    assert(batch != null, "Something is wrong. Batch should not be null!")
+
+                    // now we need to ask the resource manager if there's enough memory to execute the batch
+                    val model = modelManager.get(batch.modelName)
+
+                    val mqueue = modelQueues.get(batch.modelName)
+                    val numToDequeue = min(batch.size, mqueue.size())
+
+                    if (numToDequeue == 0) {
+                        // there are no more requests for this model, so the batch is stale
+                        queue.dequeue()
+                    } else {
+                        val memReceived = modelManager.tryAllocMem(model.name, batch.size)
+                        // without enough memory everything stays queued and is retried later
+                        if (memReceived >= 0) {
+                            // dequeue() rather than poll() so the queue's expected execution time
+                            // no longer includes this batch
+                            queue.dequeue()
+
+                            LOG.debug("Scheduling: " + numToDequeue + " for " + batch.modelName + " on " + executor.getName)
+                            val modeCode = execMode match {
+                                case ExecMode.LOCAL => 0
+                                case ExecMode.GLOBAL_MEM => 1
+                                case ExecMode.GLOBAL_DISK => 2
+                                case _ => -1
+                            }
+                            // take the original requests out of the model queue
+                            ret = Array.fill(numToDequeue) {
+                                val nextRequest = mqueue.poll()
+                                assert(nextRequest != null, "Something is wrong - request should not be null!")
+                                nextRequest.memUse = memReceived
+                                // statistics is null when statistics collection is disabled
+                                if (nextRequest.statistics != null)
+                                    nextRequest.statistics.execMode = modeCode
+                                nextRequest
+                            }
+                            LOG.debug("Done scheduling on: " + executor.getName)
+                        }
+                    }
+                }
+            }
+        }
+        ret
+    }
+
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine
+      * which models to execute next.
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which a prediction is desired
+      * @return a future completing with the response stored by the executor, or with
+      *         `dummyResponse` if the request was not processed before `timeout` elapsed
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+
+        if (_statistics) {
+            statistics.queueSize = modelQueues.get(model.name).size
+            statistics.preprocWaitTime = System.nanoTime() - request.receivedTime
+        }
+
+        modelQueues.get(model.name).add(schedulingRequest)
+
+        // CountDownLatch.await signals a timeout by returning false; it never throws TimeoutException
+        if (schedulingRequest.latch.await(timeout.length, timeout.unit))
+            schedulingRequest.response
+        else
+            dummyResponse
+    }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala
new file mode 100644
index 0000000..b67d367
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/ModelManager.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.LongAdder
+
+import org.apache.commons.logging.{Log, LogFactory}
+import org.apache.sysml.api.jmlc.{Connection, PreparedScript}
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.utils.PersistentLRUCache
+
+/**
+  * Keeps track of registered models, the executors each model was last run on ("locality") and
+  * a memory budget that batches reserve from before they run.
+  */
+trait ModelManager {
+    val LOG: Log = LogFactory.getLog(classOf[ModelManager].getName)
+    // model name -> executors the model is local to; lists are mutated under this.synchronized
+    var modelLocality = new ConcurrentHashMap[String, util.ArrayList[JmlcExecutor]]()
+    val conn: Connection = new Connection()
+    // unreserved part of the budget; check-and-reserve is done under this object's monitor
+    val availableMemory = new LongAdder
+    var totalMemory = 0L
+    var cleanupEnabled = true
+    var memCheckEnabled = true
+    var models: Map[String, Model] = Map()
+
+    /** Sets the memory budget to `memBytes` bytes and marks all of it as available. */
+    def setAvailableMemory(memBytes: Long) : Unit = {
+        LOG.info("Setting total memory to: " + memBytes + " bytes")
+        availableMemory.synchronized {
+            totalMemory = memBytes
+            availableMemory.reset()
+            availableMemory.add(memBytes)
+        }
+    }
+
+    def getAvailableMemory : Long = { availableMemory.longValue() }
+
+    /**
+      * Tries to reserve `bytes` bytes of the budget.
+      * @return `bytes` if the reservation succeeded (always, when memory checking is disabled or
+      *         `bytes` is 0), or -1 if not enough memory is available
+      */
+    def acquireMemory(bytes: Long) : Long = {
+        // if memory checking is not enabled just always say they get the memory
+        if (!memCheckEnabled || bytes == 0) {
+            bytes
+        } else {
+            LOG.debug("Requested: " + bytes)
+            // check and subtract atomically; otherwise two concurrent callers could both pass the
+            // check and over-commit the budget
+            availableMemory.synchronized {
+                if (bytes <= availableMemory.longValue()) {
+                    availableMemory.add(-1 * bytes)
+                    LOG.debug("Granted: " + bytes + "/" + availableMemory.longValue())
+                    bytes
+                } else {
+                    // not enough memory available :(
+                    LOG.debug("Insufficient memory. Request was not granted")
+                    -1L
+                }
+            }
+        }
+    }
+
+    /** Returns `bytes` previously reserved with acquireMemory to the budget (no-op if <= 0). */
+    def releaseMemory(bytes: Long) : Unit = {
+        if (bytes > 0) {
+            LOG.debug("Releasing: " + bytes)
+            availableMemory.add(bytes)
+            LOG.debug("Available memory is now: " + availableMemory.longValue())
+        }
+    }
+
+    /** Records that `model` is local to executor `exec`. */
+    def setModelLocality(model: String, exec: JmlcExecutor) : Unit = {
+        this.synchronized({
+            modelLocality.putIfAbsent(model, new util.ArrayList[JmlcExecutor]())
+            modelLocality.get(model).add(exec)
+        })
+    }
+
+    /** Removes one record of `model` being local to `exec` (no-op if there is none). */
+    def unsetModelLocality(model: String, exec: JmlcExecutor) : Unit = {
+        this.synchronized {
+            val execs = modelLocality.get(model)
+            if (execs != null)
+                execs.remove(exec)
+        }
+    }
+
+    /** @return the executors `model` is local to, or null if no locality was ever recorded */
+    def getModelLocality(model: String) : util.ArrayList[JmlcExecutor] = { modelLocality.get(model) }
+
+    /** @return true if `model` is recorded as local to `exec` */
+    def isModelLocal(model: String, exec: JmlcExecutor) : Boolean = {
+        this.synchronized {
+            val execs = getModelLocality(model)
+            execs != null && execs.contains(exec)
+        }
+    }
+
+    def disableCleanup() : Unit = { cleanupEnabled = false }
+
+    def disableMemcheck() : Unit = { memCheckEnabled = false }
+
+    def put(model: Model): Unit
+
+    def get(name: String): Model
+
+    def putWeight(name: String, weight: MatrixBlock) : Unit
+
+    def acquire(name: String, executor: JmlcExecutor) : PreparedScript
+
+    def release(name: String) : Unit
+}
+
+/**
+  * A ModelManager that pins a model's weights into its scripts while at least one batch uses the
+  * model and unpins them when the last user releases it. Weights are read through a
+  * PersistentLRUCache sized to 80% of the memory budget.
+  */
+object ReferenceCountedModelManager extends ModelManager {
+    // replaced wholesale by put(); volatile so executor threads see newly registered models
+    @volatile var modelRefCounts: Map[String,LongAdder] = Map()
+    var weightCache : PersistentLRUCache = _
+
+    override def setAvailableMemory(maxBytes: Long) : Unit = {
+        super.setAvailableMemory(maxBytes)
+        weightCache = new PersistentLRUCache((0.80*maxBytes).toLong)
+        weightCache.enableReadOnlyMode(true)
+    }
+
+    /**
+      * Tries to reserve the memory needed to run a batch of model `name`: half the weight size as
+      * working memory, plus the weights themselves when no batch currently holds them.
+      *
+      * @return the number of bytes reserved, which the caller must hand back via releaseMemory
+      *         once the batch is done, or a negative value if the memory was not available
+      */
+    def tryAllocMem(name: String, batchSize: Int) : Long = {
+        // TODO: More sophisticated memory management
+        val extraMem = (0.5*models(name).weightMem).toLong
+        val weightMem = if (modelRefCounts(name).longValue() > 0) 0L else models(name).weightMem
+        // Return everything that was reserved so the caller releases exactly this amount.
+        // (Returning only extraMem while release() separately returned weightMem made the budget
+        // drift upward whenever weightMem had not been reserved for this batch.)
+        acquireMemory(extraMem + weightMem)
+    }
+
+    /** @return true if at least one batch currently holds model `name` */
+    def isCached(name: String) : Boolean = { modelRefCounts(name).longValue() > 0 }
+
+    /**
+      * Acquires model `name` for `executor`: makes sure its weights are pinned into the script
+      * for the executor (GPU executors have a script per executor, others per executor type),
+      * increments the reference count and returns a clone of the script.
+      */
+    def acquire(name: String, executor: JmlcExecutor) : PreparedScript = {
+        LOG.debug("Acquiring model: " + name + " Ref count: " + modelRefCounts(name).longValue())
+
+        val execName = if (executor.getExecType == "GPU") executor.getName else executor.getExecType
+        val model = models(name)
+        val ps = model.script(execName)
+        // The check, the (re-)pinning and the increment happen under the same lock release() uses
+        // to unpin, so the weights cannot be cleared between the check and the clone below.
+        model.synchronized {
+            if (modelRefCounts(name).longValue() == 0 || !ps.hasPinnedData) {
+                // the weights may have to be read from disk
+                LOG.debug("Pinning weights for: " + name)
+                model.weightFiles.foreach(x => ps.setMatrix(x._1, weightCache.getAsMatrixBlock(x._2), true))
+            }
+            modelRefCounts(name).increment()
+        }
+        LOG.debug("Done acquiring model: " + name)
+        // safe outside the lock: our reference keeps the count above zero until release()
+        ps.clone(false)
+    }
+
+    override def disableCleanup(): Unit = {
+        super.disableCleanup()
+        LOG.debug("Cleanup is disabled")
+    }
+
+    /**
+      * Releases one reference to model `name`, unpinning its weights from all of its scripts when
+      * the last reference is gone. Memory is returned separately via releaseMemory.
+      */
+    def release(name: String) : Unit = {
+        val model = models(name)
+        model.synchronized {
+            modelRefCounts(name).decrement()
+            LOG.debug("Releasing model: " + name + " Ref count: " + modelRefCounts(name).longValue())
+            if (modelRefCounts(name).longValue() == 0) {
+                model.script.foreach { x => x._2.clearPinnedData() }
+            }
+        }
+    }
+
+    /** Registers `model` with a reference count of zero. */
+    def put(model: Model) : Unit = this.synchronized {
+        models += (model.name -> model)
+        modelRefCounts += (model.name -> new LongAdder())
+    }
+
+    def putWeight(name: String, weight: MatrixBlock) : Unit = {
+        weightCache.put(name, weight)
+    }
+
+    def get(name: String) : Model = { models(name) }
+
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala
new file mode 100644
index 0000000..44ff26f
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/NonBatchingScheduler.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.atomic.LongAdder
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+/**
+  * A scheduler that does not batch: every executor takes one request at a time from a single
+  * shared request queue.
+  */
+object NonBatchingScheduler extends Scheduler {
+
+    override def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Non Batching Scheduler with: ${numCores} CPUs and ${gpus} GPUs")
+        super.start(numCores, cpuMemoryBudgetInBytes, gpus)
+    }
+
+    /** @return the next queued request (as a one-element array), or an empty array if none */
+    override def schedule(executor: JmlcExecutor): Array[SchedulingRequest] = {
+        dummyResponse.synchronized {
+            // poll() yields null when the queue is empty; never hand a null request to an executor
+            Option(requestQueue.poll()).toArray
+        }
+    }
+
+    // NOTE(review): not referenced in this object
+    var requestNum = new LongAdder
+    /**
+      * Enqueues a request for processing. The scheduler will read from these queues to determine
+      * which models to execute next.
+      * @param request A PredictionRequest object containing the data for which a prediction is desired
+      * @param model The model object for which prediction is desired
+      * @return a future completing with the response stored by the executor, or with
+      *         `dummyResponse` if the request was not processed before `timeout` elapsed
+      */
+    override private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse] = Future {
+        val statistics = if (_statistics) RequestStatistics() else null
+        val schedulingRequest = SchedulingRequest(
+            request, model, new CountDownLatch(1), System.nanoTime(), null, statistics)
+        if (_statistics) statistics.queueSize = requestQueue.size()
+        requestQueue.add(schedulingRequest)
+        counter += 1
+        // CountDownLatch.await signals a timeout by returning false; it never throws TimeoutException
+        if (schedulingRequest.latch.await(timeout.length, timeout.unit))
+            schedulingRequest.response
+        else
+            dummyResponse
+    }
+
+    override def onCompleteCallback(model: String, latency: Double, batchSize: Int, execType: String, execTime: Long): Unit = {}
+}
\ No newline at end of file
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala
new file mode 100644
index 0000000..f8c2345
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/PredictionService.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.io.File
+
+import akka.http.scaladsl.server.StandardRoute
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.Http
+import akka.actor.ActorSystem
+import akka.stream.ActorMaterializer
+import org.apache.commons.cli.PosixParser
+import com.typesafe.config.ConfigFactory
+
+import scala.concurrent.duration._
+import java.util.HashMap
+
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
+import spray.json._
+import java.util.concurrent.atomic.LongAdder
+
+import scala.concurrent.{Await, Future}
+import scala.math.{max, pow}
+import org.apache.sysml.runtime.matrix.data.{MatrixBlock, OutputInfo}
+import org.apache.sysml.parser.DataExpression
+import org.apache.sysml.runtime.io.IOUtilFunctions
+import org.apache.sysml.api.jmlc.Connection
+import org.apache.sysml.api.jmlc.PreparedScript
+import org.apache.sysml.conf.ConfigurationManager
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.util.DataConverter
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+
+import scala.concurrent.ExecutionContext
+
+// format: can be file, binary, csv, ijv, jpeg, ...
+
+/**
+  * Per-request statistics, serialized to JSON (via `jsonFormat13`) as part of
+  * [[PredictionResponseExternal]]. Fields default to -1 ("" for `execType`, 0 for `execMode`),
+  * presumably meaning "not recorded".
+  * `requestDeserializationTime` and `responseSerializationTime` are filled by
+  * `PredictionService.processPredictionResponse` from `System.nanoTime()` differences (nanoseconds).
+  * NOTE(review): the remaining timing fields are set by the schedulers/executors — confirm their units.
+  */
+case class RequestStatistics(var batchSize: Int = -1,
+                             var execTime: Long = -1,
+                             var execType: String = "",
+                             var requestDeserializationTime: Long = -1,
+                             var responseSerializationTime: Long = -1,
+                             var modelAcquireTime: Long = -1,
+                             var modelReleaseTime: Long = -1,
+                             var batchingTime: Long = -1,
+                             var unbatchingTime: Long = -1,
+                             var queueWaitTime: Long = -1,
+                             var queueSize: Int = -1,
+                             var execMode: Int = 0,
+                             var preprocWaitTime: Long = -1)
+
+/**
+  * JSON body of POST /predict: `rows` x `cols` input values for the model registered as `name`.
+  * NOTE(review): `data` is passed to `MatrixBlock.init`; presumably row-major — confirm.
+  */
+case class PredictionRequestExternal(name: String, data: Array[Double], rows: Int, cols: Int)
+
+/** JSON body returned by /predict: the dense result values, their dimensions and optional statistics. */
+case class PredictionResponseExternal(response: Array[Double], rows: Int, cols: Int, statistics: RequestStatistics)
+
+/**
+  * JSON body of POST /register-model.
+  * `weightsDir` is a directory whose matrices become additional script inputs;
+  * `batchSize` and `memUse` are paired samples used to fit a linear memory model;
+  * `latencyObjective` is parsed with `scala.concurrent.duration.Duration(String)`.
+  */
+case class AddModelRequest(name: String, dml: String, inputVarName: String,
+                           outputVarName: String, weightsDir: String,
+                           latencyObjective: String, batchSize: Array[Int], memUse: Array[Long])
+
+/**
+  * A registered model.
+  * @param script      compiled scripts keyed by executor type: "CPU" and, when GPUs are configured, "GPU<i>"
+  * @param weightFiles weight variable name -> path of the (binary) file the weight was registered under
+  * @param coeffs      (a, b) of the fitted memory model memUse ~ a + b * batchSize
+  * @param weightMem   sum of `getInMemorySize` over the model's weight matrices
+  */
+case class Model(name: String,
+                 script: Map[String,PreparedScript],
+                 inputVarName: String,
+                 outputVarName: String,
+                 latencyObjective: Duration,
+                 weightFiles: Map[String, String],
+                 coeffs: (Double, Double),
+                 weightMem: Long)
+/**
+  * A deserialized prediction request. As built by `PredictionService.processPredictionRequest`,
+  * `requestSize` is the number of input rows and `receivedTime` a `System.nanoTime()` stamp.
+  */
+case class PredictionRequest(data : MatrixBlock, modelName : String, requestSize : Int, receivedTime : Long)
+/** A scheduler result: the output matrix, a batch size (presumably of the executed batch) and optional statistics. */
+case class PredictionResponse(response: MatrixBlock, batchSize: Int, statistics: RequestStatistics)
+/** A matrix block together with summary values (dimensions, non-zeros, sum). NOTE(review): usage not evident here. */
+case class MatrixBlockContainer(numRows: Long, numCols: Long, nnz: Long, sum: Double, data: MatrixBlock)
+
+/** spray-json formats for the /predict request and response bodies. */
+trait PredictionJsonProtocol extends SprayJsonSupport with DefaultJsonProtocol {
+    // Declared first: the response format below needs it to serialize `statistics`.
+    implicit val RequestStatisticsFormat: RootJsonFormat[RequestStatistics] =
+        jsonFormat13(RequestStatistics)
+
+    implicit val predictionRequestExternalFormat: RootJsonFormat[PredictionRequestExternal] =
+        jsonFormat4(PredictionRequestExternal)
+
+    implicit val predictionResponseExternalFormat: RootJsonFormat[PredictionResponseExternal] =
+        jsonFormat4(PredictionResponseExternal)
+}
+
+/** spray-json format for the /register-model request body. */
+trait AddModelJsonProtocol extends SprayJsonSupport with DefaultJsonProtocol {
+    // The member name keeps its historical spelling ("Requet") for source compatibility.
+    implicit val AddModelRequetFormat: RootJsonFormat[AddModelRequest] =
+        jsonFormat8(AddModelRequest)
+}
+
+/**
+  * Empty companion class of the [[PredictionService]] object. The object uses
+  * `classOf[PredictionService].getName` as its logger name.
+  */
+class PredictionService {
+
+}
+
+/*
+Usage:
+1. Building a fat jar with the Maven assembly plugin (our standalone jar) caused a lot of issues.
+Hence, for the time being, we recommend downloading the required jars with the script below:
+SCALA_VERSION="2.11"
+AKKA_HTTP_VERSION="10.1.3"
+AKKA_VERSION="2.5.14"
+PREFIX="https://repo1.maven.org/maven2/com/typesafe/akka/";
+JARS=""
+for PKG in actor stream protobuf
+do
+  PKG_NAME="akka-"$PKG"_"$SCALA_VERSION
+  JAR_FILE=$PKG_NAME"-"$AKKA_VERSION".jar"
+  wget $PREFIX$PKG_NAME"/"$AKKA_VERSION"/"$JAR_FILE
+  JARS=$JARS$JAR_FILE":"
+done
+for PKG in http http-core parsing
+do
+  PKG_NAME="akka-"$PKG"_"$SCALA_VERSION
+  JAR_FILE=$PKG_NAME"-"$AKKA_HTTP_VERSION".jar"
+  wget $PREFIX$PKG_NAME"/"$AKKA_HTTP_VERSION"/"$JAR_FILE
+  JARS=$JARS$JAR_FILE":"
+done
+wget https://repo1.maven.org/maven2/com/typesafe/config/1.3.3/config-1.3.3.jar
+wget https://repo1.maven.org/maven2/com/typesafe/ssl-config-core_2.11/0.2.4/ssl-config-core_2.11-0.2.4.jar
+wget https://repo1.maven.org/maven2/org/reactivestreams/reactive-streams/1.0.2/reactive-streams-1.0.2.jar
+wget https://repo1.maven.org/maven2/org/scala-lang/scala-library/2.11.12/scala-library-2.11.12.jar
+wget https://repo1.maven.org/maven2/org/scala-lang/scala-parser-combinators/2.11.0-M4/scala-parser-combinators-2.11.0-M4.jar
+wget https://repo1.maven.org/maven2/commons-cli/commons-cli/1.4/commons-cli-1.4.jar
+wget https://repo1.maven.org/maven2/com/typesafe/akka/akka-http-spray-json-experimental_2.11/2.4.11.2/akka-http-spray-json-experimental_2.11-2.4.11.2.jar
+wget https://repo1.maven.org/maven2/io/spray/spray-json_2.11/1.3.2/spray-json_2.11-1.3.2.jar
+JARS=$JARS"config-1.3.3.jar:ssl-config-core_2.11-0.2.4.jar:reactive-streams-1.0.2.jar:commons-cli-1.4.jar:scala-parser-combinators-2.11.0-M4.jar:scala-library-2.11.12.jar:akka-http-spray-json-experimental_2.11-2.4.11.2.jar:spray-json_2.11-1.3.2.jar"
+echo "Include the following jars into the classpath: "$JARS
+
+
+2. Copy SystemML.jar and systemml-1.2.0-SNAPSHOT-extra.jar into the directory where the akka jars are placed,
+and add them to the classpath: JARS=$JARS"SystemML.jar:systemml-1.2.0-SNAPSHOT-extra.jar"
+
+3. Start the server:
+java -cp $JARS org.apache.sysml.api.ml.serving.PredictionService -port 9000 
-admin_password admin
+
+4. Check the health of the server (no authentication required):
+curl -XGET localhost:9000/health
+
+5. Perform a prediction (the model must first be registered via the admin-only /register-model endpoint):
+curl -XPOST -H "Content-Type:application/json" -d '{ "name":"test", "data":[1,2,3], "rows":1, "cols":3 }' localhost:9000/predict
+
+6. Shutdown the server:
+curl -u admin -XGET localhost:9000/shutdown
+
+ */
+
+object PredictionService extends PredictionJsonProtocol with AddModelJsonProtocol {
+    val __DEBUG__ = false
+
+    val LOG = LogFactory.getLog(classOf[PredictionService].getName)
+    // Override Akka defaults: no idle timeouts and up to 100000 concurrent server connections.
+    val customConf = ConfigFactory.parseString("""
+        akka.http.server.idle-timeout=infinite
+        akka.http.client.idle-timeout=infinite
+        akka.http.host-connection-pool.idle-timeout=infinite
+        akka.http.host-connection-pool.client.idle-timeout=infinite
+        akka.http.server.max-connections=100000
+    """)
+    val basicConf = ConfigFactory.load()
+    val combined = customConf.withFallback(basicConf)
+    implicit val system = ActorSystem("systemml-prediction-service", ConfigFactory.load(combined))
+    implicit val materializer = ActorMaterializer()
+    implicit val executionContext = ExecutionContext.global
+    implicit val timeout = akka.util.Timeout(300.seconds)
+    val userPassword = new HashMap[String, String]()
+    var bindingFuture: Future[Http.ServerBinding] = null
+    var scheduler: Scheduler = null
+    val conn = new Connection()
+    var existantMatrixBlocks = Array[MatrixBlockContainer]()
+
+    /** Command-line options understood by [[main]]; only `-port` is mandatory. */
+    def getCommandLineOptions(): org.apache.commons.cli.Options = {
+        val hostOption = new org.apache.commons.cli.Option("ip", true, "IP address")
+        val portOption = new org.apache.commons.cli.Option("port", true, "Port number")
+        val numRequestOption = new org.apache.commons.cli.Option("max_requests", true, "Maximum number of requests")
+        val timeoutOption = new org.apache.commons.cli.Option("timeout", true, "Timeout in milliseconds")
+        val passwdOption = new org.apache.commons.cli.Option("admin_password", true, "Admin password. Default: admin")
+        val helpOption = new org.apache.commons.cli.Option("help", false, "Show usage message")
+        val maxSizeOption = new org.apache.commons.cli.Option("max_bytes", true, "Maximum size of request in bytes")
+        val statisticsOption = new org.apache.commons.cli.Option("statistics", true, "Gather statistics on request execution")
+        val numCpuOption = new org.apache.commons.cli.Option("num_cpus", true, "How many CPUs should be allocated to the prediction service. Default nproc-1")
+        val gpusOption = new org.apache.commons.cli.Option("gpus", true, "GPUs available to this process. Default: 0")
+        val schedulerOption = new org.apache.commons.cli.Option("scheduler", true, "Scheduler implementation to use. Default: locality-aware")
+
+        // Only port is required option
+        portOption.setRequired(true)
+
+        new org.apache.commons.cli.Options()
+          .addOption(hostOption).addOption(portOption).addOption(numRequestOption)
+          .addOption(passwdOption).addOption(timeoutOption).addOption(helpOption)
+          .addOption(maxSizeOption).addOption(statisticsOption).addOption(numCpuOption)
+          .addOption(gpusOption).addOption(schedulerOption)
+    }
+
+    /**
+      * Entry point. Parses the command line, starts the scheduler, binds the HTTP routes and
+      * blocks until the actor system terminates (see [[shutdownService]]).
+      */
+    def main(args: Array[String]): Unit = {
+        // Parse commandline variables:
+        val options = getCommandLineOptions
+        val helpFormatter = new org.apache.commons.cli.HelpFormatter()
+        val line = try {
+            new PosixParser().parse(options, args)
+        } catch {
+            case e: org.apache.commons.cli.ParseException =>
+                // e.g. the required -port is missing (also the case for a bare -help)
+                LOG.error(e.getMessage)
+                helpFormatter.printHelp("systemml-prediction-service", options)
+                // The actor system already exists; stop it so the JVM can exit.
+                system.terminate()
+                return
+        }
+        if (line.hasOption("help")) {
+            helpFormatter.printHelp("systemml-prediction-service", options)
+            system.terminate()
+            return
+        }
+        userPassword.put("admin", line.getOptionValue("admin_password", "admin"))
+        val currNumRequests = new LongAdder
+        val maxNumRequests = if (line.hasOption("max_requests"))
+            line.getOptionValue("max_requests").toLong else Long.MaxValue
+        val timeout = if (line.hasOption("timeout"))
+            Duration(line.getOptionValue("timeout").toLong, MILLISECONDS) else 300.seconds
+        val sizeDirective = if (line.hasOption("max_bytes"))
+            withSizeLimit(line.getOptionValue("max_bytes").toLong) else withoutSizeLimit
+        // Default: all but one available processor, but never fewer than one.
+        val numCores = if (line.hasOption("num_cpus"))
+            line.getOptionValue("num_cpus").toInt else max(Runtime.getRuntime.availableProcessors() - 1, 1)
+        val gpus = if (line.hasOption("gpus")) line.getOptionValue("gpus") else null
+        val schedulerType = line.getOptionValue("scheduler", "locality-aware")
+
+        // Initialize statistics counters
+        val numTimeouts = new LongAdder
+        val numFailures = new LongAdder
+        val totalTime = new LongAdder
+        val numCompletedPredictions = new LongAdder
+
+        // For now the models need to be loaded every time. TODO: pass the local to serialized models via commandline
+        // NOTE(review): `models` is read and replaced from concurrent request handlers without synchronization.
+        var models = Map[String, Model]()
+
+        scheduler = SchedulerFactory.getScheduler(schedulerType)
+        // Runtime.maxMemory is the maximum amount of memory the JVM will attempt to use.
+        val maxMemory = Runtime.getRuntime.maxMemory()
+
+        LOG.info("Total memory allocated to server: " + maxMemory)
+        scheduler.start(numCores, maxMemory, gpus)
+
+        // Define unsecured routes: /predict and /health
+        val unsecuredRoutes = {
+            path("predict") {
+                withoutRequestTimeout {
+                    post {
+                        validate(currNumRequests.longValue() < maxNumRequests, "The prediction server received too many requests. Ignoring the current request.") {
+                            entity(as[PredictionRequestExternal]) { request =>
+                                validate(models.contains(request.name), "The model is not available.") {
+                                    try {
+                                        currNumRequests.increment()
+                                        val start = System.nanoTime()
+                                        val processedRequest = processPredictionRequest(request)
+                                        val deserializationTime = System.nanoTime() - start
+
+                                        val response = Await.result(
+                                            scheduler.enqueue(processedRequest, models(request.name)), timeout)
+                                        totalTime.add(System.nanoTime() - start)
+
+                                        numCompletedPredictions.increment()
+                                        complete(StatusCodes.OK, processPredictionResponse(response, "NOT IMPLEMENTED", deserializationTime))
+                                    } catch {
+                                        case e: scala.concurrent.TimeoutException => {
+                                            numTimeouts.increment()
+                                            complete(StatusCodes.RequestTimeout, "Timeout occured")
+                                        }
+                                        case e: Exception => {
+                                            numFailures.increment()
+                                            val msg = "Exception occured while executing the prediction request:"
+                                            LOG.error(msg, e)
+                                            complete(StatusCodes.InternalServerError, msg + e.getMessage)
+                                        }
+                                    } finally {
+                                        currNumRequests.decrement()
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            } ~ path("health") {
+                get {
+                    // currNumRequests is the number of requests currently being processed.
+                    val completed = numCompletedPredictions.longValue()
+                    val avgMillis = if (completed > 0) (totalTime.doubleValue() * 1e-6) / completed else 0.0
+                    // Every line ends with `+` so the whole message forms one expression.
+                    val stats = "Number of requests (total/completed/timeout/failures):" +
+                      currNumRequests.longValue() + "/" + completed + "/" +
+                      numTimeouts.longValue() + "/" + numFailures.longValue() + ".\n" +
+                      "Average prediction time:" + avgMillis + " ms.\n"
+                    complete(StatusCodes.OK, stats)
+                }
+            }
+        }
+
+        // For administration: This can be later extended for supporting multiple users.
+        val securedRoutes = {
+            authenticateBasicAsync(realm = "secure site", userAuthenticate) {
+                user =>
+                    path("shutdown") {
+                        get {
+                            shutdownService(user, scheduler)
+                        }
+                    } ~
+                      path("register-model") {
+                          withoutRequestTimeout {
+                              post {
+                                  entity(as[AddModelRequest]) { request =>
+                                      validate(!models.contains(request.name), "The model is already loaded") {
+                                          try {
+                                              val weightsInfo = processWeights(request.weightsDir)
+                                              val inputs = weightsInfo._1.keys.toArray ++ Array[String](request.inputVarName)
+
+                                              // compile for executor types
+                                              val scriptCpu = conn.prepareScript(
+                                                  request.dml, inputs, Array[String](request.outputVarName))
+                                              var scripts = Map("CPU" -> scriptCpu)
+
+                                              if (gpus != null) {
+                                                  GPUContextPool.AVAILABLE_GPUS = gpus
+                                                  for (ix <- 0 until GPUContextPool.getAvailableCount) {
+                                                      LOG.info("Compiling script for GPU: " + ix)
+                                                      scripts += (s"GPU${ix}" -> conn.prepareScript(
+                                                          request.dml, inputs, Array[String](request.outputVarName),
+                                                          true, true, ix))
+                                                  }
+                                              }
+
+                                              val (a, b) = fitMemoryModel(request.batchSize, request.memUse)
+
+                                              // now register the created model
+                                              val model = Model(request.name,
+                                                  scripts,
+                                                  request.inputVarName,
+                                                  request.outputVarName,
+                                                  Duration(request.latencyObjective),
+                                                  weightsInfo._1, (a, b), weightsInfo._2)
+                                              models += (request.name -> model)
+                                              scheduler.addModel(model)
+                                              complete(StatusCodes.OK)
+                                          } catch {
+                                              case e: Exception => {
+                                                  numFailures.increment()
+                                                  LOG.error("Exception occured while trying to add model", e)
+                                                  complete(StatusCodes.InternalServerError,
+                                                      "Exception occured while trying to add model:" + e.getMessage)
+                                              }
+                                          }
+                                      }
+                                  }
+                              }
+                          }
+                      }
+            }
+        }
+
+        bindingFuture = Http().bindAndHandle(
+            sizeDirective { // Both secured and unsecured routes need to respect the size restriction
+                unsecuredRoutes ~ securedRoutes
+            },
+            line.getOptionValue("ip", "localhost"), line.getOptionValue("port").toInt)
+        // Without this, a failed bind (e.g. port in use) would leave the process hanging silently.
+        bindingFuture.failed.foreach { e =>
+            LOG.error("Failed to bind the prediction server", e)
+            scheduler.shutdown()
+            system.terminate()
+        }
+
+        println(s"Prediction Server online.")
+        // Block until the actor system terminates; shutdownService unbinds the server and
+        // terminates the system.
+        Await.ready(system.whenTerminated, Duration.Inf)
+    }
+
+    /**
+      * Least-squares fit of memUse ~ a + b * batchSize over the profiling samples supplied at
+      * model registration: b = cov(x, y) / var(x), a = mean(y) - b * mean(x).
+      * When var(x) is zero (no samples, or a single distinct batch size) the slope is undefined,
+      * so b = 0 and the model degenerates to the mean memory use instead of NaN/Infinity.
+      *
+      * @return the coefficients (a, b)
+      * @throws IllegalArgumentException if the two arrays differ in length
+      */
+    private def fitMemoryModel(batchSize: Array[Int], memUse: Array[Long]): (Double, Double) = {
+        require(batchSize.length == memUse.length,
+            "batchSize and memUse must have the same length")
+        val n = max(batchSize.length, 1).toDouble
+        val mux = batchSize.sum / n
+        val muy = memUse.sum / n
+        val vx = (1 / n) * batchSize.map(v => pow(v - mux, 2.0)).sum
+        val b = if (vx > 0) {
+            ((1 / n) * (batchSize.map(v => v - mux) zip memUse.map(v => v - muy))
+              .map(v => v._1 * v._2).sum) * (1 / vx)
+        } else 0.0
+        val a = muy - b * mux
+        (a, b)
+    }
+
+    /**
+      * Converts a scheduler response into its JSON representation and records the
+      * (de)serialization times in its statistics, if present.
+      * A missing response, or one without a result matrix, becomes an empty result with -1 dimensions.
+      * NOTE(review): `getDenseBlockValues` is used as-is; a sparse result may not carry dense values — confirm.
+      */
+    def processPredictionResponse(response : PredictionResponse,
+                                  format : String,
+                                  deserializationTime: Long) : PredictionResponseExternal = {
+        if (response == null || response.response == null) {
+            PredictionResponseExternal(null, -1, -1, null)
+        } else {
+            val start = System.nanoTime()
+            val dataArray = response.response.getDenseBlockValues
+            val rows = response.response.getNumRows
+            val cols = response.response.getNumColumns
+            val serializationTime = System.nanoTime() - start
+            if (response.statistics != null) {
+                response.statistics.requestDeserializationTime = deserializationTime
+                response.statistics.responseSerializationTime = serializationTime
+            }
+            PredictionResponseExternal(dataArray, rows, cols, response.statistics)
+        }
+    }
+
+    /**
+      * Registers every weight matrix found in `dirname` with the scheduler's model manager.
+      * @return (weight name -> registered file path, total in-memory size of all weights)
+      */
+    def processWeights(dirname: String) : (Map[String, String], Long) = {
+        val dir = new File(dirname)
+        if (!(dir.exists && dir.isDirectory))
+            throw new Exception("Weight directory: " + dirname + " is invalid")
+
+        // Skip directories whose path contains "binary" (converted weights are written to
+        // <dir>/binary), metadata files ending in "mtd", and already converted "_bin.mtx" files.
+        val weightsWithSize = dir.listFiles()
+          .filter(f => !(f.isDirectory && (f.toString contains "binary")))
+          .map(_.toString)
+          .filter(p => !p.endsWith("mtd") && !(p contains "_bin.mtx"))
+          .map(p => getNameFromPath(p) -> registerWeight(p, dirname))
+          .toMap
+
+        val weightMap = weightsWithSize.map(x => x._1 -> x._2._1)
+        val totalSize = weightsWithSize.map(x => x._2._2).sum
+
+        (weightMap, totalSize)
+    }
+
+    /** File name without directory and without anything from the first '.' on. */
+    def getNameFromPath(path: String) : String = {
+        new File(path).getName.split("\\.")(0)
+    }
+
+    /** Loads a weight (converting it to binary if needed) and hands it to the model manager. */
+    def registerWeight(path: String, dir: String) : (String, Long) = {
+        val res = convertToBinaryIfNecessary(path, dir)
+        scheduler.modelManager.putWeight(res._2, res._1)
+        (res._2, res._1.getInMemorySize)
+    }
+
+    /**
+      * Reads the matrix at `path`; if its metadata does not declare binary format, writes a
+      * binary-block copy to `<dir>/binary/<name>.mtx`.
+      * @return the matrix and the path it should be registered under
+      */
+    def convertToBinaryIfNecessary(path: String, dir: String) : (MatrixBlock, String) = {
+        LOG.info("Reading weight: " + path)
+        val data = conn.readMatrix(path)
+
+        val pathActual = if (isBinaryFormat(path)) path else {
+            LOG.info("Converting weight to binary format")
+            val binPath = dir + "/binary/" + getNameFromPath(path) + ".mtx"
+            DataConverter.writeMatrixToHDFS(data, binPath,
+                OutputInfo.BinaryBlockOutputInfo,
+                new MatrixCharacteristics(data.getNumRows, data.getNumColumns, ConfigurationManager.getBlocksize,
+                    ConfigurationManager.getBlocksize, data.getNonZeros))
+            binPath
+        }
+        (data, pathActual)
+    }
+
+    /** True iff the matrix's .mtd metadata file declares `"format": "binary"`. */
+    def isBinaryFormat(path: String) : Boolean = {
+        val mtdName = DataExpression.getMTDFileName(path)
+        val mtd = new DataExpression().readMetadataFile(mtdName, false)
+        mtd.containsKey("format") && mtd.getString("format") == "binary"
+    }
+
+    /**
+      * Builds the internal request from the JSON body.
+      * @throws IllegalArgumentException if `data` does not hold exactly `rows * cols` values
+      */
+    def processPredictionRequest(request : PredictionRequestExternal) : PredictionRequest = {
+        // Reject malformed client input up front with a descriptive message.
+        require(request.rows >= 0 && request.cols >= 0 && request.data != null &&
+            request.data.length.toLong == request.rows.toLong * request.cols,
+            s"Expected ${request.rows} x ${request.cols} values in 'data'")
+        val mat = new MatrixBlock(request.rows, request.cols, false)
+        mat.init(request.data, request.rows, request.cols)
+        PredictionRequest(mat, request.name, request.rows, System.nanoTime())
+    }
+
+    /** Parses textual matrix input; only "csv" is supported. */
+    def processMatrixInput(data : String, rows : Int, cols : Int, format : String) : MatrixBlock = {
+        val result = format match {
+            case "csv" => processTextInput(data, rows, cols, DataExpression.FORMAT_TYPE_VALUE_CSV)
+            case _ => throw new Exception("Only CSV Input currently supported")
+        }
+        result
+    }
+
+    def processTextInput(data : String, rows : Int, cols : Int, format : String) : MatrixBlock = {
+        val is = IOUtilFunctions.toInputStream(data)
+        conn.convertToMatrix(is, rows, cols, format)
+    }
+
+    /** Basic-auth check against `userPassword`; resolves to the user id on success. */
+    def userAuthenticate(credentials: akka.http.scaladsl.server.directives.Credentials): Future[Option[String]] = {
+        credentials match {
+            case p @ akka.http.scaladsl.server.directives.Credentials.Provided(id) =>
+                Future {
+                    if (userPassword.containsKey(id) && p.verify(userPassword.get(id))) Some(id)
+                    else None
+                }
+            case _ => Future.successful(None)
+        }
+    }
+
+    /**
+      * Admin-only shutdown. Closes connection pools, unbinds the server, terminates the actor
+      * system and stops the scheduler, then exits the JVM shortly after replying.
+      */
+    def shutdownService(user: String, scheduler: Scheduler): StandardRoute = {
+        if (user.equals("admin")) {
+            try {
+                Http().shutdownAllConnectionPools() andThen { case _ => bindingFuture.flatMap(_.unbind()).onComplete(_ => system.terminate()) }
+                scheduler.shutdown()
+                complete(StatusCodes.OK, "Shutting down the server.")
+            } finally {
+                new Thread(new Runnable {
+                    def run(): Unit = {
+                        // Give the reply ~100 ms to be sent, then terminate the prediction JVM.
+                        Thread.sleep(100)
+                        System.exit(0)
+                    }
+                }).start()
+            }
+        }
+        else {
+            complete(StatusCodes.BadRequest, "Only admin can shutdown the service.")
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala
new file mode 100644
index 0000000..03bc8cd
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/RLSEstimator.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import breeze.linalg._
+import breeze.numerics._
+import breeze.stats._
+
+/**
+  * Online cubic regression of latency on batch size. Observations are buffered and folded in
+  * `chunkSize` at a time with a recursive-least-squares style update.
+  *
+  * NOTE(review): the first fit solves with `X.t * X` built from one chunk of `chunkSize` = 2 rows
+  * against 4 features, so that 4x4 system is rank-deficient — confirm the initial solve is sound.
+  */
+class RLSEstimator {
+    // Pending (batchSize, latency) observations, consumed `chunkSize` at a time by update().
+    val dataQueue = new LinkedBlockingQueue[(Double, Double)]()
+    val chunkSize = 2
+
+    var isInitialized = false
+    // Set once pow(lda, n) drops below eps; afterwards enqueueExample ignores new observations.
+    var isFinalized = false
+    var Q : DenseMatrix[Double] = _   // accumulated X^T X over all processed chunks
+    var b : DenseMatrix[Double] = _   // 4x1 coefficients for the features [1, x, x^2, x^3]; null until the first update
+    var n = 0                         // number of update() calls so far
+    val lda = 0.98                    // decay base, applied as pow(lda, n)
+    val eps = 0.00000001
+    var sigma = -1.0                  // variance of the residuals on the last chunk; -1 before the first update
+
+    /** Buffers one observation (unless finalized) and runs update() once a full chunk is queued. */
+    def enqueueExample(batchSize: Int, latency: Double) : Unit = {
+        if (!isFinalized) {
+            // Debug output to stdout.
+            println("ENQUEUING => " + dataQueue.size())
+            dataQueue.add((batchSize.toDouble, latency))
+            if (dataQueue.size() >= chunkSize)
+                update()
+        }
+    }
+
+    /**
+      * Removes `chunkSize` observations and returns the design matrix X (rows [1, x, x^2, x^3])
+      * and the target column y.
+      * Assumes at least `chunkSize` observations are queued: poll() returns null otherwise,
+      * which fails the tuple pattern below.
+      */
+    def dequeueExamples() : (DenseMatrix[Double], DenseMatrix[Double]) = {
+        val X = DenseMatrix.zeros[Double](chunkSize,4)
+        val y = DenseMatrix.zeros[Double](chunkSize, 1)
+
+        for (ix <- 0 until chunkSize) {
+            val (x_ex, y_ex) = dataQueue.poll()
+            X(ix,::) := DenseVector[Double](1.0, x_ex, pow(x_ex,2), pow(x_ex,3)).t
+            y(ix,0) = y_ex
+        }
+        (X, y)
+    }
+
+    /**
+      * Folds the next chunk into the estimate. The first call solves the normal equations
+      * directly; later calls apply a gain-weighted correction while pow(lda, n) >= eps, after
+      * which the estimator is finalized and the queue is cleared. `sigma` is always refreshed.
+      */
+    def update() : Unit = {
+        val s = pow(lda, n)
+        val R = dequeueExamples()
+        val X = R._1
+        val y = R._2
+        if (!isInitialized) {
+            Q = X.t * X
+            b = Q \ (X.t * y)
+            isInitialized = true
+        } else if (s >= eps) {
+            val Q_new = Q + (X.t * X)
+            // NOTE(review): this weight scales the NEWEST chunk and shrinks as n grows, so later data
+            // counts less; classic RLS forgetting down-weights older data instead — confirm intent.
+            val S = pow(lda, n) * DenseMatrix.eye[Double](chunkSize)
+            val K = inv(Q_new) * (X.t * S) // Kalman filter gain
+            val V = y - (X * b) // Innovations
+            b :+= K * V
+            Q = Q_new
+        } else {
+            isFinalized = true
+            dataQueue.clear()
+        }
+        sigma = variance(y - (X*b))
+        n += 1
+    }
+
+    /**
+      * @return (predicted latency for `batchSize`, clamped at 0; current residual variance `sigma`).
+      *         Requires at least one update(): `b` is null before that.
+      */
+    def predict(batchSize: Int) : (Double,Double) = {
+        val x = DenseMatrix(1.0, batchSize, pow(batchSize,2), pow(batchSize,3)).reshape(1,4)
+        val y_hat = x*b
+        (max(y_hat(0,0), 0.0), sigma)
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala 
b/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala
new file mode 100644
index 0000000..39e85c0
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/Scheduler.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import java.util.concurrent._
+import java.util.List
+
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+
+import scala.concurrent.ExecutionContext
+
+/**
+  * Bookkeeping record for one prediction request while it is tracked by a [[Scheduler]].
+  *
+  * @param request      the prediction request payload
+  * @param model        the model the request targets
+  * @param latch        latch associated with this request (presumably released once
+  *                     `response` is available - confirm with the executor code)
+  * @param receivedTime time stamp recorded when the request was received (units not visible here)
+  * @param response     mutable response slot; null until assigned
+  * @param statistics   optional per-request statistics; null by default
+  * @param memUse       mutable memory-use figure for this request; 0 by default
+  */
+case class SchedulingRequest(
+    request: PredictionRequest,
+    model: Model,
+    latch: CountDownLatch,
+    receivedTime: Long,
+    var response: PredictionResponse = null,
+    statistics: RequestStatistics = null,
+    var memUse: Long = 0
+)
+
+/**
+  * Base trait of the serving schedulers.
+  *
+  * It owns the executor thread pool (one [[JmlcExecutor]] per CPU core and per GPU context),
+  * the per-model request queues, and the latency objectives of registered models.
+  * Implementations provide [[schedule]], [[onCompleteCallback]] and [[enqueue]].
+  */
+trait Scheduler {
+    val LOG: Log = LogFactory.getLog(classOf[Scheduler].getName)
+    var executorService: ExecutorService = _
+    protected var _statistics = true
+    // Explicit type: public implicit members should not rely on type inference.
+    implicit val ec: scala.concurrent.ExecutionContextExecutor = ExecutionContext.global
+    var executorTypes = Array[String]()
+    var modelManager = ReferenceCountedModelManager
+
+    /**
+      * Creates the executor thread pool and submits one executor per CPU core and one per
+      * available GPU context, each registered with its own [[BatchQueue]].
+      *
+      * @param numCores               number of CPU executors to start
+      * @param cpuMemoryBudgetInBytes CPU memory budget; 80% of it is handed to the model manager
+      * @param gpus                   value assigned to GPUContextPool.AVAILABLE_GPUS, or null to
+      *                               start without GPU executors
+      * @throws IllegalArgumentException if numCores plus the number of GPUs is not positive
+      */
+    def start(numCores: Int, cpuMemoryBudgetInBytes: Long, gpus: String): Unit = {
+        LOG.info(s"Starting Scheduler with ${numCores} CPUs and ${gpus} GPUs")
+        val gCtxs: List[GPUContext] =
+            if (gpus != null) {
+                GPUContextPool.AVAILABLE_GPUS = gpus
+                GPUContextPool.getAllGPUContexts
+            } else {
+                java.util.Collections.emptyList[GPUContext]()
+            }
+        val numGpus = gCtxs.size
+        // Executors.newFixedThreadPool rejects a non-positive pool size with a bare
+        // IllegalArgumentException; fail with the same type but an actionable message.
+        require(numCores + numGpus > 0,
+            s"Scheduler requires at least one executor, got ${numCores} CPUs and ${numGpus} GPUs")
+
+        executorService = Executors.newFixedThreadPool(numCores + numGpus)
+        modelManager.setAvailableMemory((cpuMemoryBudgetInBytes*0.80).toLong)
+
+        if (numCores > 0)
+            executorTypes :+= "CPU"
+        if (numGpus > 0)
+            executorTypes :+= "GPU"
+
+        LOG.debug("STARTING SCHEDULER WITH: " + numCores + " CPU => " + numGpus + " GPUS")
+        for (i <- 0 until numCores) {
+            val exec = new JmlcExecutor(this, "CPU", "CPU" + i, null)
+            executorQueues.put(exec, new BatchQueue("CPU", "CPU" + i))
+            executorService.submit(exec)
+        }
+        for (i <- 0 until numGpus) {
+            val exec = new JmlcExecutor(this, "GPU","GPU" + i, gCtxs.get(i))
+            executorQueues.put(exec, new BatchQueue("GPU", "GPU" + i))
+            executorService.submit(exec)
+        }
+    }
+
+    /**
+      * Initiates an orderly shutdown of the executor pool. Does nothing when [[start]]
+      * has never been called and the pool is therefore still null.
+      */
+    def shutdown(): Unit = {
+        Option(executorService).foreach(_.shutdown())
+    }
+
+    /**
+      * Selects the requests the given executor should run next.
+      *
+      * @param executor the executor asking for work
+      * @return the scheduling requests handed to that executor
+      */
+    def schedule(executor: JmlcExecutor): Array[SchedulingRequest]
+
+    /**
+      * Registers a model with this scheduler. This should be called before enqueueing requests.
+      * It creates the model's request queue and records its latency objective, unless either
+      * already exists, and hands the model to the model manager.
+      *
+      * @param model Model object to be registered
+      */
+    def addModel(model: Model): Unit = {
+        modelQueues.putIfAbsent(model.name, new LinkedBlockingDeque[SchedulingRequest]())
+        latencyObjectives.putIfAbsent(model.name, model.latencyObjective)
+        modelManager.put(model)
+    }
+
+    /**
+      * Sets a flag indicating whether detailed statistics should be gathered. These statistics
+      * profile the time spent in the various stages of the execution pipeline.
+      *
+      * @param flag Boolean flag indicating whether statistics should be gathered
+      */
+    def setStatistics(flag: Boolean): Unit = { _statistics = flag }
+
+    /** Timeout exposed to scheduler implementations; 300 seconds by default. */
+    def timeout: Duration = 300.seconds
+
+    /**
+      * Updates scheduler state after a batch has executed. Objects implementing the
+      * Scheduler trait implement any logic needed to post-process a finished batch here.
+      *
+      * @param model     name of the model that was just executed
+      * @param latency   a measure of latency for this batch
+      * @param batchSize the number of examples in the batch
+      * @param execType  the device type on which the batch was executed
+      * @param execTime  execution time of the batch (units not visible here - confirm with the caller)
+      */
+    def onCompleteCallback(model: String, latency: Double, batchSize: Int, execType: String, execTime: Long) : Unit
+
+    val requestQueue = new LinkedBlockingDeque[SchedulingRequest]()
+    val globalSchedulingQueues = new ConcurrentHashMap[String, BatchQueue]()
+    // Per-model request queues, keyed by model name (populated by addModel).
+    var modelQueues = new ConcurrentHashMap[String, BlockingQueue[SchedulingRequest]]()
+    // Batch queue of every executor created by start().
+    var executorQueues = new ConcurrentHashMap[JmlcExecutor, BatchQueue]()
+    val dummyResponse = PredictionResponse(null, -1, null)
+    // Latency objective per model name (populated by addModel).
+    val latencyObjectives = new ConcurrentHashMap[String, Duration]()
+    // NOTE(review): not read within this trait; presumably used by implementations.
+    var counter = 0
+
+    /**
+      * Enqueues a request for processing. The scheduler reads from these queues to decide
+      * which models to execute next.
+      *
+      * @param request a PredictionRequest holding the data for which a prediction is desired
+      * @param model   the model that should produce the prediction
+      * @return a future completed with the prediction response for this request
+      */
+    private[serving] def enqueue(request: PredictionRequest, model: Model): Future[PredictionResponse]
+}
diff --git a/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala b/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala
new file mode 100644
index 0000000..6fa6007
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/serving/SchedulerFactory.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml.serving
+
+object SchedulerFactory {
+  /**
+    * Maps a scheduler name to the corresponding [[Scheduler]] implementation.
+    *
+    * @param schedulerType one of "non-batching", "basic-batching" or "locality-aware"
+    * @return the matching scheduler
+    * @throws IllegalArgumentException if schedulerType is not one of the supported names
+    */
+  def getScheduler(schedulerType: String) : Scheduler = {
+    schedulerType match {
+      case "non-batching"   => NonBatchingScheduler
+      case "basic-batching" => BasicBatchingScheduler
+      case "locality-aware" => LocalityAwareScheduler
+      // Without this case an unknown name surfaces as an opaque scala.MatchError.
+      case other =>
+        throw new IllegalArgumentException(
+          s"Unknown scheduler type '$other'; expected one of: non-batching, basic-batching, locality-aware")
+    }
+  }
+}
\ No newline at end of file

Reply via email to