Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/03/16 23:48:26 UTC

[incubator-mxnet] branch master updated: [MXNET-50] Scala Inference APIs (#9678)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b95ae7c  [MXNET-50] Scala Inference APIs (#9678)
b95ae7c is described below

commit b95ae7c9c8c3cf40a12cfc729d5dadaa322ee0dd
Author: Naveen Swamy <mn...@gmail.com>
AuthorDate: Fri Mar 16 16:48:21 2018 -0700

    [MXNET-50] Scala Inference APIs (#9678)
    
    * Scala Inference APIs
    
    * fix unit tests for shape.length == layout.length in DataDesc
    
    * make ThreadPoolHandler of size 1
    
    * Rename PredictBase to Predictor
    
    * change classify output from List to IndexedSeq
    
    * modify MXNetHandler to check if the task is executing on the same thread that created the handler
    
    * add argument epoch for Predictor/Classifier
---
 .../core/src/main/scala/ml/dmlc/mxnet/IO.scala     |   4 +
 .../src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala |  31 ++--
 scala-package/examples/pom.xml                     |   6 +
 scala-package/infer/pom.xml                        |  84 +++++++++
 .../scala/ml/dmlc/mxnet/infer/Classifier.scala     | 170 +++++++++++++++++
 .../scala/ml/dmlc/mxnet/infer/MXNetHandler.scala   | 103 +++++++++++
 .../main/scala/ml/dmlc/mxnet/infer/Predictor.scala | 198 ++++++++++++++++++++
 .../main/scala/ml/dmlc/mxnet/infer/package.scala   |  22 +++
 .../infer/src/test/resources/log4j.properties      |  24 +++
 .../ml/dmlc/mxnet/infer/ClassifierSuite.scala      | 205 +++++++++++++++++++++
 .../scala/ml/dmlc/mxnet/infer/PredictorSuite.scala | 114 ++++++++++++
 scala-package/pom.xml                              |   1 +
 12 files changed, 946 insertions(+), 16 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
index 7bc936f..8426316 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -230,6 +230,10 @@ abstract class DataPack() extends Iterable[DataBatch] {
 // Named data desc description contains name, shape, type and other extended attributes.
 case class DataDesc(name: String, shape: Shape,
                     dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
+  require(shape.length == layout.length, ("number of dimensions in shape :%d with" +
+    " shape: %s should match the length of the layout: %d with layout: %s").
+    format(shape.length, shape.toString, layout.length, layout))
+
   override def toString(): String = {
     s"DataDesc[$name,$shape,$dtype,$layout]"
   }
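
A minimal sketch of the new DataDesc constraint (the shapes below are arbitrary examples, not part of this patch): the layout string must now have one character per dimension of the shape.

    import ml.dmlc.mxnet.{DataDesc, Shape}

    DataDesc("data", Shape(1, 3, 224, 224), layout = "NCHW") // 4 dims, 4-char layout: ok
    DataDesc("data", Shape(10, 10), layout = "NT")           // 2 dims, 2-char layout: ok
    // DataDesc("data", Shape(10, 10)) would now fail: the default layout "NCHW" has length 4
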
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
index ab48ef7..d747c63 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
@@ -22,7 +22,6 @@ import ml.dmlc.mxnet.CheckUtils._
 import ml.dmlc.mxnet.module._
 import ml.dmlc.mxnet.optimizer._
 import ml.dmlc.mxnet.io._
-
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
   test ("model dtype") {
     val dType = DType.Float16
@@ -55,9 +54,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
     mod.bind(dataShapes = IndexedSeq(
-      DataDesc("b", Shape(5, 5)),
-      DataDesc("c", Shape(5, 5)),
-      DataDesc("a", Shape(5, 5))),
+      DataDesc("b", Shape(5, 5), layout = "NT"),
+      DataDesc("c", Shape(5, 5), layout = "NT"),
+      DataDesc("a", Shape(5, 5), layout = "NT")),
       inputsNeedGrad = true
     )
     mod.initParams()
@@ -108,14 +107,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // single device
     var mod = new Module(sym, IndexedSeq("data"), null)
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     mod.update()
     mod.saveCheckpoint("test", 0, saveOptStates = true)
 
     var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
-    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
     mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
     mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -123,14 +122,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     // multi device
     mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" )))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     mod.update()
     mod.saveCheckpoint("test", 0, saveOptStates = true)
 
     mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
-    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
     mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
     mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -143,7 +142,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     var dShape = Shape(7, 20)
     val mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape)))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT")))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 1f))
 
@@ -156,7 +155,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
 
     dShape = Shape(14, 20)
-    mod.reshape(IndexedSeq(DataDesc("data", dShape)))
+    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
     mod.forward(new DataBatch(
       data = IndexedSeq(NDArray.ones(dShape)),
       label = null, index = null, pad = 0))
@@ -167,8 +166,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   test ("module setParams") {
-    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
-    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
     val trainData = new NDArrayIter(
       IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
 
@@ -217,8 +216,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
   test ("monitor") {
     // data iter
-    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
-    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
     val trainData = new NDArrayIter(
       IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
 
@@ -295,8 +294,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     val mod = new Module(sym, IndexedSeq("data1", "data2"))
     mod.bind(dataShapes = IndexedSeq(
-      DataDesc("data1", dShape1), DataDesc("data2", dShape2)),
-      labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape)))
+      DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")),
+      labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N")))
     )
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 351f71f..0a9c0b0 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -122,6 +122,12 @@
       <scope>provided</scope>
     </dependency>
     <dependency>
+      <groupId>ml.dmlc.mxnet</groupId>
+      <artifactId>mxnet-infer</artifactId>
+      <version>1.2.0-SNAPSHOT</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
       <groupId>com.sksamuel.scrimage</groupId>
       <artifactId>scrimage-core_2.11</artifactId>
       <version>2.1.8</version>
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
new file mode 100644
index 0000000..3ae8f6c
--- /dev/null
+++ b/scala-package/infer/pom.xml
@@ -0,0 +1,84 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <artifactId>mxnet-parent_2.11</artifactId>
+        <groupId>ml.dmlc.mxnet</groupId>
+        <version>1.2.0-SNAPSHOT</version>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>mxnet-infer</artifactId>
+    <name>MXNet Scala Package - Inference</name>
+
+    <profiles>
+        <profile>
+            <id>osx-x86_64-cpu</id>
+            <properties>
+                <platform>osx-x86_64-cpu</platform>
+            </properties>
+        </profile>
+        <profile>
+            <id>linux-x86_64-cpu</id>
+            <properties>
+                <platform>linux-x86_64-cpu</platform>
+            </properties>
+        </profile>
+        <profile>
+            <id>linux-x86_64-gpu</id>
+            <properties>
+                <platform>linux-x86_64-gpu</platform>
+            </properties>
+        </profile>
+    </profiles>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-jar-plugin</artifactId>
+                <configuration>
+                    <excludes>
+                        <exclude>META-INF/*.SF</exclude>
+                        <exclude>META-INF/*.DSA</exclude>
+                        <exclude>META-INF/*.RSA</exclude>
+                    </excludes>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+            </plugin>
+            <plugin>
+                <groupId>org.scalatest</groupId>
+                <artifactId>scalatest-maven-plugin</artifactId>
+                <configuration>
+                    <argLine>
+                        -Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+                        -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+                    </argLine>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.scalastyle</groupId>
+                <artifactId>scalastyle-maven-plugin</artifactId>
+            </plugin>
+        </plugins>
+    </build>
+    <dependencies>
+        <dependency>
+            <groupId>ml.dmlc.mxnet</groupId>
+            <artifactId>mxnet-core_${scala.binary.version}</artifactId>
+            <version>1.2.0-SNAPSHOT</version>
+            <scope>provided</scope>
+        </dependency>
+        <!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-all</artifactId>
+            <version>1.10.19</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+</project>
\ No newline at end of file
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
new file mode 100644
index 0000000..6eec81c
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
+import java.io.File
+
+import org.slf4j.LoggerFactory
+
+import scala.io
+import scala.collection.mutable.ListBuffer
+
+trait ClassifierBase {
+
+  /**
+    * Takes an array of floats and returns the corresponding (label, score) tuples.
+    * @param input: IndexedSeq of one-dimensional arrays of floats.
+    * @param topK: (Optional) how many top-k elements to return (sorting is based on
+    *             the last axis); if not passed, returns unsorted output.
+    * @return IndexedSeq of (label, score) tuples.
+    */
+  def classify(input: IndexedSeq[Array[Float]],
+               topK: Option[Int] = None): IndexedSeq[(String, Float)]
+
+  /**
+    * Takes a sequence of NDArrays and returns (label, score) tuples.
+    * @param input: IndexedSeq of NDArrays
+    * @param topK: (Optional) how many top-k elements to return (sorting is based on
+    *             the last axis); if not passed, returns unsorted output.
+    * @return IndexedSeq of IndexedSeqs of (label, score) tuples, one per input item.
+    */
+  def classifyWithNDArray(input: IndexedSeq[NDArray],
+                          topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]]
+}
+
+/**
+  * A class for classifier tasks
+  * @param modelPathPrefix Path prefix from which to load the symbol, parameters and synset.txt.
+  *                        Example: file://model-dir/resnet-152 (the directory containing
+  *                        resnet-152-symbol.json, the parameters file and synset.txt)
+  * @param inputDescriptors Descriptors defining the input node names, shape,
+  *                         layout and Type parameters
+  * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+  * @param epoch Model epoch to load, defaults to 0.
+  */
+class Classifier(modelPathPrefix: String,
+                 protected val inputDescriptors: IndexedSeq[DataDesc],
+                 protected val contexts: Array[Context] = Context.cpu(),
+                 protected val epoch: Option[Int] = Some(0))
+  extends ClassifierBase {
+
+  private val logger = LoggerFactory.getLogger(classOf[Classifier])
+
+  protected[infer] val predictor: PredictBase = getPredictor()
+
+  protected[infer] val synsetFilePath = getSynsetFilePath(modelPathPrefix)
+
+  protected[infer] val synset = readSynsetFile(synsetFilePath)
+
+  protected[infer] val handler = MXNetHandler()
+
+  /**
+    * Takes flat arrays as input and returns (label, score) tuples.
+    * @param input: IndexedSeq of one-dimensional arrays of floats.
+    * @param topK: (Optional) how many top-k elements to return (sorting is based on
+    *             the last axis); if not passed, returns unsorted output.
+    * @return IndexedSeq of (label, score) tuples.
+    */
+  override def classify(input: IndexedSeq[Array[Float]],
+                        topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
+
+    // considering only the first output
+    val predictResult = predictor.predict(input)(0)
+    var result: IndexedSeq[(String, Float)] = IndexedSeq.empty
+
+    if (topK.isDefined) {
+      val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
+      result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq
+    } else {
+      result = synset.zip(predictResult).toIndexedSeq
+    }
+    result
+  }
+
+  /**
+    * Takes input as NDArrays; useful when you want to perform multiple operations on
+    * the input arrays or when you want to pass a batch of input.
+    * @param input: IndexedSeq of NDArrays
+    * @param topK: (Optional) how many top-k elements to return (sorting is based on
+    *             the last axis); if not passed, returns unsorted output.
+    * @return IndexedSeq of IndexedSeqs of (label, score) tuples, one per input item.
+    */
+  override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
+  : IndexedSeq[IndexedSeq[(String, Float)]] = {
+
+    // considering only the first output
+    val predictResultND: NDArray = predictor.predictWithNDArray(input)(0)
+
+    val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
+
+    // iterating over the individual items(batch size is in axis 0)
+    for (i <- 0 until predictResultND.shape(0)) {
+      val r = predictResultND.at(i)
+      predictResult += r.toArray
+      r.dispose()
+    }
+
+    var result: ListBuffer[IndexedSeq[(String, Float)]] =
+      ListBuffer.empty[IndexedSeq[(String, Float)]]
+
+    if (topK.isDefined) {
+      val sortedIndices = predictResult.map(r =>
+        r.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
+      )
+      for (i <- sortedIndices.indices) {
+        result += sortedIndices(i).map(sIndx =>
+          (synset(sIndx), predictResult(i)(sIndx))).toIndexedSeq
+      }
+    } else {
+      for (i <- predictResult.indices) {
+        result += synset.zip(predictResult(i)).toIndexedSeq
+      }
+    }
+
+    handler.execute(predictResultND.dispose())
+
+    result.toIndexedSeq
+  }
+
+  private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
+    val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
+    val d = new File(dirPath)
+    require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))
+
+    val s = new File(dirPath + "synset.txt")
+    require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format
+    (dirPath + "synset.txt"))
+
+    s.getCanonicalPath
+  }
+
+  private[infer]  def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
+    val f = io.Source.fromFile(synsetFilePath)
+    try {
+      f.getLines().toIndexedSeq
+    } finally {
+      f.close
+    }
+  }
+
+  private[infer] def getPredictor(): PredictBase = {
+      new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch)
+  }
+
+}
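
A hedged usage sketch for the Classifier added above; the model prefix, input shape and file locations are assumptions for illustration, not part of this commit. Checkpoint files and synset.txt are expected alongside the prefix.

    import ml.dmlc.mxnet.{DataDesc, Shape}
    import ml.dmlc.mxnet.infer.Classifier

    // hypothetical model dir holding resnet-152-symbol.json, the params file and synset.txt
    val inputDesc = IndexedSeq(DataDesc("data", Shape(1, 3, 224, 224), layout = "NCHW"))
    val classifier = new Classifier("/tmp/models/resnet-152", inputDesc)

    val pixels = Array.fill[Float](3 * 224 * 224)(0.5f)      // placeholder image data
    val top5 = classifier.classify(IndexedSeq(pixels), topK = Some(5))
    top5.foreach { case (label, score) => println(s"$label -> $score") }
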
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala
new file mode 100644
index 0000000..2859f83
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import java.util.concurrent._
+
+import org.slf4j.LoggerFactory
+
+private[infer] trait MXNetHandler {
+
+  def execute[T](f: => T): T
+
+  val executor: ExecutorService
+
+}
+
+private[infer] object MXNetHandlerType extends Enumeration {
+
+  type MXNetHandlerType = Value
+  val SingleThreadHandler = Value("MXNetSingleThreadHandler")
+  val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
+}
+
+private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1)
+  extends MXNetHandler {
+
+  require(numThreads > 0, "numThreads should be a positive number, you passed:%d".
+    format(numThreads))
+
+  private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])
+  private var threadCount: Int = 0
+
+  private val threadFactory = new ThreadFactory {
+
+    override def newThread(r: Runnable): Thread = new Thread(r) {
+      setName(classOf[MXNetThreadPoolHandler].getCanonicalName
+        + "-%d".format(threadCount))
+      threadCount += 1
+    }
+  }
+
+  override val executor: ExecutorService =
+    Executors.newFixedThreadPool(numThreads, threadFactory)
+
+  private val creatorThread = executor.submit(new Callable[Thread] {
+    override def call(): Thread = Thread.currentThread()
+  }).get()
+
+  override def execute[T](f: => T): T = {
+
+    if (Thread.currentThread() eq creatorThread) {
+      f
+    } else {
+
+      val task = new Callable[T] {
+        override def call(): T = {
+          logger.info("threadId: %s".format(Thread.currentThread().getId()))
+          f
+        }
+      }
+
+      val result = executor.submit(task)
+      try {
+        result.get()
+      } catch {
+        case e : InterruptedException => throw e
+        // unwrap the exception thrown by the task
+        case e1: Exception => throw e1.getCause()
+      }
+    }
+  }
+
+}
+
+private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1) {
+
+}
+
+private[infer] object MXNetHandler {
+
+  def apply(): MXNetHandler = {
+    if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
+      new MXNetThreadPoolHandler(1)
+    } else {
+      MXNetSingleThreadHandler
+    }
+  }
+}
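
The handler above funnels every call into the MXNet engine through one dedicated thread and short-circuits when the caller is already on that thread. A small illustrative sketch of that contract (the NDArray work inside the block is an arbitrary example; MXNetHandler is private[infer], so this only compiles inside that package):

    import ml.dmlc.mxnet.{NDArray, Shape}

    val handler = MXNetHandler()           // SingleThreadHandler by default (see package.scala)
    val total: Float = handler.execute {
      val a = NDArray.ones(Shape(2, 2))    // executed on the handler's creator thread,
      val s = a.toArray.sum                // or inline if we are already on it
      a.dispose()
      s
    }
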
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala
new file mode 100644
index 0000000..6be3b98
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import ml.dmlc.mxnet.io.NDArrayIter
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
+import ml.dmlc.mxnet.module.Module
+
+import scala.collection.mutable.ListBuffer
+import org.slf4j.LoggerFactory
+
+/**
+ * Base Trait for MXNet Predictor classes.
+ */
+private[infer] trait PredictBase {
+
+  /**
+   * Takes input as an IndexedSeq of one-dimensional arrays and creates the
+   * NDArrays needed for inference. The arrays are reshaped based on the input descriptors.
+   * @param input: An IndexedSeq of Scala one-dimensional arrays; an IndexedSeq is
+   *             needed when the model has more than one input.
+   * @return IndexedSeq of output arrays.
+   */
+  def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+
+  /**
+   * Predict using NDArrays as input. This method is useful when the input is a batch of data
+   * or when multiple operations on the input have to be performed.
+   * Note: the user is responsible for managing allocation/deallocation of the NDArrays.
+   * @param input: IndexedSeq of NDArrays.
+   * @return output predictions as NDArrays.
+   */
+  def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]
+
+}
+
+/**
+ * Implementation of predict routines.
+ *
+ * @param modelPathPrefix Path prefix from which to load the model.
+ *                        Example: file://model-dir/resnet-152 (the directory containing
+ *                        resnet-152-symbol.json and the parameters file)
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ *                         layout and type parameters.
+ * <p>Note: If the input descriptors are missing the batch size ('N' in the layout),
+ * a batch size of 1 is assumed for the model.
+ * </p>
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+ * @param epoch Model epoch to load, defaults to 0.
+ */
+class Predictor(modelPathPrefix: String,
+                protected val inputDescriptors: IndexedSeq[DataDesc],
+                protected val contexts: Array[Context] = Context.cpu(),
+                protected val epoch: Option[Int] = Some(0))
+                extends PredictBase {
+
+  private val logger = LoggerFactory.getLogger(classOf[Predictor])
+
+  require(inputDescriptors.head.layout.size != 0, "layout size should not be zero")
+
+  protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
+  protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex)
+    else 1
+
+  protected[infer] var iDescriptors = inputDescriptors
+
+  inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex,
+    "batch size should be in the same index for all inputs"))
+
+  if (batchIndex != -1) {
+    inputDescriptors.foreach((f: DataDesc) => require(f.shape(batchIndex) == batchSize,
+      "batch size should be same for all inputs"))
+  } else {
+    // Note: this is assuming that the input needs a batch
+    logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize")
+    iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name,
+      Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout))
+    batchIndex = 1
+  }
+
+  protected[infer] val mxNetHandler = MXNetHandler()
+
+  protected[infer] val mod = loadModule()
+
+  /**
+   * Takes input as an IndexedSeq of one-dimensional arrays and creates the
+   * NDArrays needed for inference. The arrays are reshaped based on the input descriptors.
+   *
+   * @param input : An IndexedSeq of Scala one-dimensional arrays; an IndexedSeq is
+   *              needed when the model has more than one input.
+   * @return IndexedSeq of output arrays.
+   */
+  override def predict(input: IndexedSeq[Array[Float]])
+  : IndexedSeq[Array[Float]] = {
+
+    require(input.length == inputDescriptors.length, ("number of inputs provided: %d" +
+      " does not match number of inputs in inputDescriptors: %d").format(input.length,
+        inputDescriptors.length))
+
+    for((i, d) <- input.zip(inputDescriptors)) {
+      require (i.length == d.shape.product/batchSize, "number of elements:" +
+        " %d in the input does not match the shape:%s".format( i.length, d.shape.toString()))
+    }
+    var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]
+
+    for((i, d) <- input.zip(inputDescriptors)) {
+      val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
+
+      inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
+    }
+
+    // rebind with batchsize 1
+    if (batchSize != 1) {
+      val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
+        Shape(f.shape.toVector.patch(batchIndex, Vector(1), 1)), f.dtype, f.layout) )
+      mxNetHandler.execute(mod.bind(desc, forceRebind = true,
+        forTraining = false))
+    }
+
+    val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
+      inputND.toIndexedSeq, dataBatchSize = 1)))
+
+    val result = resultND.map((f : NDArray) => f.toArray)
+
+    mxNetHandler.execute(inputND.foreach(_.dispose))
+    mxNetHandler.execute(resultND.foreach(_.dispose))
+
+    // rebind to batchSize
+    if (batchSize != 1) {
+      mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
+    }
+
+    result
+  }
+
+  /**
+   * Predict using NDArrays as input. This method is useful when the input is a batch of data.
+   * Note: the user is responsible for managing allocation/deallocation of input/output NDArrays.
+   *
+   * @param inputBatch : IndexedSeq of NDArrays.
+   * @return output predictions as NDArrays.
+   */
+  override def predictWithNDArray(inputBatch: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
+
+    require(inputBatch.length == inputDescriptors.length, ("number of inputs provided: %d" +
+      " does not match number of inputs in inputDescriptors: %d").format(inputBatch.length,
+        inputDescriptors.length))
+
+    // Shape validation, remove this when backend throws better error messages.
+    for((i, d) <- inputBatch.zip(iDescriptors)) {
+       require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
+         "All inputs should be of same batch size")
+      require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
+        "Input Data Shape: %s should match the inputDescriptor shape: %s except batchSize".format(
+          i.shape.toString, d.shape.toString))
+    }
+
+    val inputBatchSize = inputBatch(0).shape(batchIndex)
+
+    // rebind with the new batchSize
+    if (batchSize != inputBatchSize) {
+      val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
+        Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout) )
+      mxNetHandler.execute(mod.bind(desc, forceRebind = true,
+        forTraining = false))
+    }
+
+    val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
+      inputBatch, dataBatchSize = inputBatchSize)))
+
+    if (batchSize != inputBatchSize) {
+      mxNetHandler.execute(mod.bind(iDescriptors, forceRebind = true,
+        forTraining = false))
+    }
+    resultND
+  }
+
+  private[infer] def loadModule(): Module = {
+    val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.get,
+      contexts = contexts))
+    mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false))
+    mod
+  }
+}
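
A hedged usage sketch for the Predictor added above; the model prefix and shapes are assumptions for illustration, not part of this commit.

    import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
    import ml.dmlc.mxnet.infer.Predictor

    // the 'N' axis of the layout supplies the batch size (1 here)
    val inputDesc = IndexedSeq(DataDesc("data", Shape(1, 3, 224, 224), layout = "NCHW"))
    val predictor = new Predictor("/tmp/models/resnet-152", inputDesc)

    // flat-array path: one Array[Float] per input node, reshaped internally
    val probs: IndexedSeq[Array[Float]] =
      predictor.predict(IndexedSeq(Array.fill[Float](3 * 224 * 224)(0.5f)))

    // NDArray path: the caller owns allocation/deallocation of inputs and outputs
    val batch = NDArray.ones(Shape(1, 3, 224, 224))
    val outputs = predictor.predictWithNDArray(IndexedSeq(batch))
    outputs.foreach(_.dispose())
    batch.dispose()
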
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
new file mode 100644
index 0000000..4e99d56
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet
+
+package object infer {
+  private[mxnet] val handlerType = MXNetHandlerType.SingleThreadHandler
+}
diff --git a/scala-package/infer/src/test/resources/log4j.properties b/scala-package/infer/src/test/resources/log4j.properties
new file mode 100644
index 0000000..d82fd7e
--- /dev/null
+++ b/scala-package/infer/src/test/resources/log4j.properties
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# for development debugging
+log4j.rootLogger = debug, stdout
+
+log4j.appender.stdout = org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target = System.out
+log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala
new file mode 100644
index 0000000..1a2f423
--- /dev/null
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import java.util
+
+import ml.dmlc.mxnet.module.Module
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.io
+
+class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
+
+  private val logger = LoggerFactory.getLogger(classOf[Predictor])
+
+  var modelPath = ""
+
+  var synFilePath = ""
+
+  def createTempModelFiles(): Unit = {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    logger.info("tempDirPath: %s".format(tempDirPath))
+
+    val modelDirPath = tempDirPath + File.separator + "model"
+    val synPath = tempDirPath + File.separator + "synset.txt"
+    val synsetFile = new File(synPath)
+    synsetFile.createNewFile()
+    val lines: util.List[String] = util.Arrays.
+      asList("class1 label1", "class2 label2", "class3 label3", "class4 label4")
+    val path = Paths.get(synPath)
+    Files.write(path, lines)
+
+    this.modelPath = modelDirPath
+    this.synFilePath = synsetFile.getCanonicalPath
+    logger.info("modelPath: %s".format(this.modelPath))
+    logger.info("synFilePath: %s".format(this.synFilePath))
+  }
+
+  override def beforeAll() {
+    createTempModelFiles
+  }
+
+  override def afterAll() {
+    new File(synFilePath).delete()
+  }
+
+  class MyClassyPredictor(val modelPathPrefix: String,
+                          override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  class MyClassifier(modelPathPrefix: String,
+                     protected override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Classifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
+
+    override def getPredictor(): MyClassyPredictor = {
+      Mockito.mock(classOf[MyClassyPredictor])
+    }
+    def getSynset(): IndexedSeq[String] = synset
+  }
+
+  test("ClassifierSuite-getSynsetFilePath") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifer = new MyClassifier(modelPath, inputDescriptor)
+
+    assertResult(this.synFilePath) {
+      testClassifer.synsetFilePath
+    }
+  }
+
+  test("ClassifierSuite-readSynsetFile") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifer = new MyClassifier(modelPath, inputDescriptor)
+
+    assertResult(io.Source.fromFile(this.synFilePath).getLines().toList) {
+      testClassifer.getSynset()
+    }
+  }
+
+  test("ClassifierSuite-flatArray-topK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult : IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+          classify(IndexedSeq(inputData), topK = Some(10))
+
+    assertResult(predictResult(0).sortBy(-_)) {
+      result.map(_._2).toArray
+    }
+
+  }
+
+  test("ClassifierSuite-flatArrayInput") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult : IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+          classify(IndexedSeq(inputData))
+
+    assertResult(predictResult(0)) {
+      result.map(_._2).toArray
+    }
+  }
+
+  test("ClassifierSuite-NDArray1InputWithoutTopK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(1, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(1, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+          classifyWithNDArray(IndexedSeq(inputData))
+
+    assert(predictResult.size == result.size)
+
+    for(i <- predictResult.indices) {
+      assertResult(predictResult(i)) {
+        result(i).map(_._2).toArray
+      }
+    }
+  }
+
+  test("ClassifierSuite-NDArray3InputWithTopK") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(3, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f),
+        Array(.98f, 0.97f, 0.96f, 0.99f), Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(3, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+          classifyWithNDArray(IndexedSeq(inputData), topK = Some(10))
+
+    assert(predictResult.size == result.size)
+
+    for(i <- predictResult.indices) {
+      assertResult(predictResult(i).sortBy(-_)) {
+        result(i).map(_._2).toArray
+      }
+    }
+  }
+
+}
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
new file mode 100644
index 0000000..da4d965
--- /dev/null
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+
+import ml.dmlc.mxnet.io.NDArrayIter
+import ml.dmlc.mxnet.module.{BaseModule, Module}
+import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class PredictorSuite extends FunSuite with BeforeAndAfterAll {
+
+  class MyPredictor(val modelPathPrefix: String,
+                    override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  test("PredictorSuite-testPredictorConstruction") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)))
+
+    val mockPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    assert(mockPredictor.getBatchSize == 1)
+    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))
+
+    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)),
+      new DataDesc("data", Shape(2, 3, 2, 2)))
+
+    assertThrows[IllegalArgumentException] {
+      val mockPredictor = new MyPredictor("xyz", inputDescriptor2)
+    }
+
+    // batchsize is defaulted to 1
+    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
+    val p2 = new MyPredictor("xyz", iDesc2)
+    assert(p2.getBatchSize == 1, "should use a default batch size of 1")
+
+  }
+
+  test("PredictorSuite-testWithFlatArrays") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    // this will be disposed at the end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1 ")
+
+    assert(Array.fill[Float](12)(1).mkString == testFun(0).mkString)
+
+    // Verify that the module was bound with batch size 1 and rebound back to the original
+    // input descriptor. the number of times is twice here because loadModule overrides the
+    // initial bind.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+
+  test("PredictorSuite-testWithNDArray") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = NDArray.ones(Shape(1, 3, 2, 2))
+
+    // this will be disposed at the end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1")
+
+    assert(Array.fill[Float](12)(1).mkString == testFun(0).toArray.mkString)
+
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 02bcd86..27dfe2f 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -37,6 +37,7 @@
     <module>macros</module>
     <module>core</module>
     <module>native</module>
+    <module>infer</module>
     <module>examples</module>
     <module>spark</module>
     <module>assembly</module>

-- 
To stop receiving notification emails like this one, please contact
nswamy@apache.org.