You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/03/16 23:48:26 UTC
[incubator-mxnet] branch master updated: [MXNET-50] Scala Inference
APIs (#9678)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b95ae7c [MXNET-50] Scala Inference APIs (#9678)
b95ae7c is described below
commit b95ae7c9c8c3cf40a12cfc729d5dadaa322ee0dd
Author: Naveen Swamy <mn...@gmail.com>
AuthorDate: Fri Mar 16 16:48:21 2018 -0700
[MXNET-50] Scala Inference APIs (#9678)
* Scala Inference APIs
* fix unit tests for shape.length == layout.length in DataDesc
* make ThreadPoolHandler of size 1
* Rename PredictBase to Predictor
* change classify output from List to IndexedSeq
* modify MXNetHandler to run a task inline when it is already executing on the handler's own executor thread
* add argument epoch for Predictor/Classifier
---
.../core/src/main/scala/ml/dmlc/mxnet/IO.scala | 4 +
.../src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala | 31 ++--
scala-package/examples/pom.xml | 6 +
scala-package/infer/pom.xml | 84 +++++++++
.../scala/ml/dmlc/mxnet/infer/Classifier.scala | 170 +++++++++++++++++
.../scala/ml/dmlc/mxnet/infer/MXNetHandler.scala | 103 +++++++++++
.../main/scala/ml/dmlc/mxnet/infer/Predictor.scala | 198 ++++++++++++++++++++
.../main/scala/ml/dmlc/mxnet/infer/package.scala | 22 +++
.../infer/src/test/resources/log4j.properties | 24 +++
.../ml/dmlc/mxnet/infer/ClassifierSuite.scala | 205 +++++++++++++++++++++
.../scala/ml/dmlc/mxnet/infer/PredictorSuite.scala | 114 ++++++++++++
scala-package/pom.xml | 1 +
12 files changed, 946 insertions(+), 16 deletions(-)
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
index 7bc936f..8426316 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -230,6 +230,10 @@ abstract class DataPack() extends Iterable[DataBatch] {
// Named data desc description contains name, shape, type and other extended attributes.
case class DataDesc(name: String, shape: Shape,
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
+ require(shape.length == layout.length, ("number of dimensions in shape :%d with" +
+ " shape: %s should match the length of the layout: %d with layout: %s").
+ format(shape.length, shape.toString, layout.length, layout))
+
override def toString(): String = {
s"DataDesc[$name,$shape,$dtype,$layout]"
}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
index ab48ef7..d747c63 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
@@ -22,7 +22,6 @@ import ml.dmlc.mxnet.CheckUtils._
import ml.dmlc.mxnet.module._
import ml.dmlc.mxnet.optimizer._
import ml.dmlc.mxnet.io._
-
class ModuleSuite extends FunSuite with BeforeAndAfterAll {
test ("model dtype") {
val dType = DType.Float16
@@ -55,9 +54,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
mod.bind(dataShapes = IndexedSeq(
- DataDesc("b", Shape(5, 5)),
- DataDesc("c", Shape(5, 5)),
- DataDesc("a", Shape(5, 5))),
+ DataDesc("b", Shape(5, 5), layout = "NT"),
+ DataDesc("c", Shape(5, 5), layout = "NT"),
+ DataDesc("a", Shape(5, 5), layout = "NT")),
inputsNeedGrad = true
)
mod.initParams()
@@ -108,14 +107,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
// single device
var mod = new Module(sym, IndexedSeq("data"), null)
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
mod.update()
mod.saveCheckpoint("test", 0, saveOptStates = true)
var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
- mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -123,14 +122,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
// multi device
mod = new Module(sym, IndexedSeq("data"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" )))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
mod.update()
mod.saveCheckpoint("test", 0, saveOptStates = true)
mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
- mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -143,7 +142,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
var dShape = Shape(7, 20)
val mod = new Module(sym, IndexedSeq("data"), null,
contexts = Array(Context.cpu(0), Context.cpu(1)))
- mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT")))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 1f))
@@ -156,7 +155,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
dShape = Shape(14, 20)
- mod.reshape(IndexedSeq(DataDesc("data", dShape)))
+ mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
mod.forward(new DataBatch(
data = IndexedSeq(NDArray.ones(dShape)),
label = null, index = null, pad = 0))
@@ -167,8 +166,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
}
test ("module setParams") {
- val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
- val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
@@ -217,8 +216,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
test ("monitor") {
// data iter
- val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
- val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
val trainData = new NDArrayIter(
IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
@@ -295,8 +294,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
val mod = new Module(sym, IndexedSeq("data1", "data2"))
mod.bind(dataShapes = IndexedSeq(
- DataDesc("data1", dShape1), DataDesc("data2", dShape2)),
- labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape)))
+ DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")),
+ labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N")))
)
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 351f71f..0a9c0b0 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -122,6 +122,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>ml.dmlc.mxnet</groupId>
+ <artifactId>mxnet-infer</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>com.sksamuel.scrimage</groupId>
<artifactId>scrimage-core_2.11</artifactId>
<version>2.1.8</version>
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
new file mode 100644
index 0000000..3ae8f6c
--- /dev/null
+++ b/scala-package/infer/pom.xml
@@ -0,0 +1,84 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <parent>
+ <artifactId>mxnet-parent_2.11</artifactId>
+ <groupId>ml.dmlc.mxnet</groupId>
+ <version>1.2.0-SNAPSHOT</version>
+ </parent>
+ <modelVersion>4.0.0</modelVersion>
+
+ <artifactId>mxnet-infer</artifactId>
+ <name>MXNet Scala Package - Inference</name>
+
+ <profiles>
+ <profile>
+ <id>osx-x86_64-cpu</id>
+ <properties>
+ <platform>osx-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-cpu</id>
+ <properties>
+ <platform>linux-x86_64-cpu</platform>
+ </properties>
+ </profile>
+ <profile>
+ <id>linux-x86_64-gpu</id>
+ <properties>
+ <platform>linux-x86_64-gpu</platform>
+ </properties>
+ </profile>
+ </profiles>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <excludes>
+ <exclude>META-INF/*.SF</exclude>
+ <exclude>META-INF/*.DSA</exclude>
+ <exclude>META-INF/*.RSA</exclude>
+ </excludes>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <argLine>
+ -Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+ -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+ </argLine>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.scalastyle</groupId>
+ <artifactId>scalastyle-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ <dependencies>
+ <dependency>
+ <groupId>ml.dmlc.mxnet</groupId>
+ <artifactId>mxnet-core_${scala.binary.version}</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <scope>provided</scope>
+ </dependency>
+ <!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <version>1.10.19</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
\ No newline at end of file
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
new file mode 100644
index 0000000..6eec81c
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Classifier.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
+import java.io.File
+
+import org.slf4j.LoggerFactory
+
+import scala.io
+
+/**
+ * Common interface for classifiers that map model outputs to (Label, Score) tuples.
+ */
+trait ClassifierBase {
+
+  /**
+   * Takes an IndexedSeq of flat Float arrays and returns (Label, Score) tuples.
+   * @param input: IndexedSequence of one-dimensional arrays of Floats, one per model input.
+   * @param topK: (Optional) How many top_k (sorting will be based on the last axis)
+   *             elements to return; if not passed, returns unsorted output.
+   * @return IndexedSequence of (Label, Score) tuples.
+   */
+  def classify(input: IndexedSeq[Array[Float]],
+               topK: Option[Int] = None): IndexedSeq[(String, Float)]
+
+  /**
+   * Takes a Sequence of NDArrays and returns (Label, Score) tuples for every batch item.
+   * @param input: Indexed Sequence of NDArrays
+   * @param topK: (Optional) How many top_k (sorting will be based on the last axis)
+   *             elements to return; if not passed, returns unsorted output.
+   * @return One IndexedSequence of (Label, Score) tuples per batch item.
+   */
+  def classifyWithNDArray(input: IndexedSeq[NDArray],
+                          topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]]
+}
+
+/**
+ * A class for classifier tasks
+ * @param modelPathPrefix PathPrefix from where to load the symbol, parameters and synset.txt
+ *                        Example: file://model-dir/resnet-152(containing resnet-152-symbol.json
+ *                        file://model-dir/synset.txt
+ *                        synset.txt is looked up in the directory part of the prefix
+ *                        (the current directory if the prefix has no directory part).
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ *                         layout and Type parameters
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+ * @param epoch Model epoch to load, defaults to 0.
+ */
+class Classifier(modelPathPrefix: String,
+                 protected val inputDescriptors: IndexedSeq[DataDesc],
+                 protected val contexts: Array[Context] = Context.cpu(),
+                 protected val epoch: Option[Int] = Some(0))
+  extends ClassifierBase {
+
+  private val logger = LoggerFactory.getLogger(classOf[Classifier])
+
+  protected[infer] val predictor: PredictBase = getPredictor()
+
+  protected[infer] val synsetFilePath = getSynsetFilePath(modelPathPrefix)
+
+  protected[infer] val synset = readSynsetFile(synsetFilePath)
+
+  protected[infer] val handler = MXNetHandler()
+
+  /**
+   * Takes flat arrays as input and returns (Label, Score) tuples computed from the
+   * first output of the model.
+   * @param input: IndexedSequence of one-dimensional arrays of Floats, one per model input.
+   * @param topK: (Optional) How many top_k (sorting will be based on the last axis)
+   *             elements to return; if not passed, returns unsorted output.
+   * @return IndexedSequence of (Label, Score) tuples.
+   */
+  override def classify(input: IndexedSeq[Array[Float]],
+                        topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
+    // considering only the first output
+    val predictResult = predictor.predict(input)(0)
+    labelScores(predictResult, topK)
+  }
+
+  /**
+   * Takes input as NDArrays, useful when you want to perform multiple operations on
+   * the input Array or when you want to pass a batch of input.
+   * @param input: Indexed Sequence of NDArrays
+   * @param topK: (Optional) How many top_k (sorting will be based on the last axis)
+   *             elements to return; if not passed, returns unsorted output.
+   * @return One IndexedSequence of (Label, Score) tuples per batch item.
+   */
+  override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
+    : IndexedSeq[IndexedSeq[(String, Float)]] = {
+
+    // considering only the first output
+    val predictResultND: NDArray = predictor.predictWithNDArray(input)(0)
+
+    // Copy every batch item (batch size is in axis 0) to the JVM. These NDArray calls run
+    // on the handler, and the output NDArray is freed even if copying fails.
+    val predictResult: IndexedSeq[Array[Float]] = handler.execute {
+      try {
+        (0 until predictResultND.shape(0)).map { i =>
+          val r = predictResultND.at(i)
+          try r.toArray finally r.dispose()
+        }
+      } finally {
+        predictResultND.dispose()
+      }
+    }
+
+    predictResult.map(scores => labelScores(scores, topK))
+  }
+
+  /**
+   * Pairs scores with their synset labels.
+   * @param scores one score per synset entry
+   * @param topK if defined, keep only the k highest scores in descending order;
+   *             otherwise keep every label in synset order
+   */
+  private def labelScores(scores: Array[Float],
+                          topK: Option[Int]): IndexedSeq[(String, Float)] = topK match {
+    case Some(k) =>
+      scores.zipWithIndex.sortBy(-_._1).take(k)
+        .map { case (score, idx) => (synset(idx), score) }.toIndexedSeq
+    case None =>
+      synset.zip(scores).toIndexedSeq
+  }
+
+  private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
+    val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
+    // a prefix without a directory part (e.g. "resnet-152") refers to the working directory
+    val d = if (dirPath.isEmpty) new File(".") else new File(dirPath)
+    require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))
+
+    val s = new File(d, "synset.txt")
+    require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format(
+      dirPath + "synset.txt"))
+
+    s.getCanonicalPath
+  }
+
+  private[infer] def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
+    val f = io.Source.fromFile(synsetFilePath)
+    try {
+      f.getLines().toIndexedSeq
+    } finally {
+      f.close()
+    }
+  }
+
+  private[infer] def getPredictor(): PredictBase = {
+    new Predictor(modelPathPrefix, inputDescriptors, contexts, epoch)
+  }
+
+}
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala
new file mode 100644
index 0000000..2859f83
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/MXNetHandler.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.slf4j.LoggerFactory
+
+/**
+ * Runs MXNet calls on the thread(s) owned by the handler.
+ */
+private[infer] trait MXNetHandler {
+
+  /**
+   * Evaluates `f` on the handler's executor and returns its result.
+   */
+  def execute[T](f: => T): T
+
+  val executor: ExecutorService
+
+}
+
+private[infer] object MXNetHandlerType extends Enumeration {
+
+  type MXNetHandlerType = Value
+  val SingleThreadHandler = Value("MXNetSingleThreadHandler")
+  val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
+}
+
+/**
+ * MXNetHandler backed by a fixed-size pool of daemon threads.
+ * @param numThreads number of threads in the pool, must be positive. Only the first pool
+ *                   thread is recognised as "already on the handler" in execute, so
+ *                   numThreads = 1 is the intended configuration.
+ */
+private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1)
+  extends MXNetHandler {
+
+  require(numThreads > 0, "numThreads should be a positive number, you passed:%d".
+    format(numThreads))
+
+  private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])
+  // the factory may be invoked from whichever thread submits work, hence atomic
+  private val threadCount = new AtomicInteger(0)
+
+  private val threadFactory = new ThreadFactory {
+
+    override def newThread(r: Runnable): Thread = new Thread(r) {
+      setName(classOf[MXNetThreadPoolHandler].getCanonicalName
+        + "-%d".format(threadCount.getAndIncrement()))
+      // nothing in this file shuts the executor down, so use daemon threads that do not
+      // keep the JVM alive once the application's own threads have finished
+      setDaemon(true)
+    }
+  }
+
+  override val executor: ExecutorService =
+    Executors.newFixedThreadPool(numThreads, threadFactory)
+
+  // the pool thread that runs the first task; calls made on it are executed inline
+  private val creatorThread = executor.submit(new Callable[Thread] {
+    override def call(): Thread = Thread.currentThread()
+  }).get()
+
+  override def execute[T](f: => T): T = {
+
+    if (Thread.currentThread() eq creatorThread) {
+      // already on the handler thread: submitting would deadlock a single-thread pool
+      f
+    } else {
+
+      val task = new Callable[T] {
+        override def call(): T = {
+          logger.debug("threadId: %s".format(Thread.currentThread().getId()))
+          f
+        }
+      }
+
+      val result = executor.submit(task)
+      try {
+        result.get()
+      } catch {
+        // unwrap the exception thrown by the task; InterruptedException and
+        // CancellationException propagate unchanged
+        case e: ExecutionException if e.getCause != null => throw e.getCause
+      }
+    }
+  }
+
+}
+
+private[infer] object MXNetSingleThreadHandler extends MXNetThreadPoolHandler(1)
+
+private[infer] object MXNetHandler {
+
+  /**
+   * Returns a handler according to the package-level `handlerType`: a new single-thread
+   * handler per call for OneThreadPerModelHandler, otherwise the shared singleton.
+   */
+  def apply(): MXNetHandler = {
+    if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
+      new MXNetThreadPoolHandler(1)
+    } else {
+      MXNetSingleThreadHandler
+    }
+  }
+}
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala
new file mode 100644
index 0000000..6be3b98
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/Predictor.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import ml.dmlc.mxnet.io.NDArrayIter
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
+import ml.dmlc.mxnet.module.Module
+
+import org.slf4j.LoggerFactory
+
+/**
+ * Base Trait for MXNet Predictor classes.
+ */
+private[infer] trait PredictBase {
+
+  /**
+   * This method will take input as IndexedSeq one dimensional arrays and creates
+   * NDArray needed for inference. The array will be reshaped based on the input descriptors.
+   * @param input: An IndexedSequence of Scala one-dimensional arrays; an IndexedSequence
+   *             is needed when the model has more than one input
+   * @return IndexedSequence array of outputs.
+   */
+  def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+
+  /**
+   * Predict using NDArray as input. This method is useful when the input is a batch of data
+   * or when multiple operations on the input have to be performed.
+   * Note: User is responsible for managing allocation/deallocation of NDArrays.
+   * @param input: IndexedSequence NDArrays.
+   * @return output of Predictions as NDArrays.
+   */
+  def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]
+
+}
+
+/**
+ * Implementation of predict routines.
+ *
+ * @param modelPathPrefix PathPrefix from where to load the model.
+ *                        Example: file://model-dir/resnet-152(containing resnet-152-symbol.json,
+ * @param inputDescriptors Descriptors defining the input node names, shape,
+ *                         layout and Type parameters.
+ *                         <p>Note: If the input Descriptors is missing batchSize
+ *                         ('N' in layout), a batch axis of size 1 is prepended and the
+ *                         model is bound with it.</p>
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+ * @param epoch Model epoch to load, defaults to 0 (also used when None is passed).
+ */
+class Predictor(modelPathPrefix: String,
+                protected val inputDescriptors: IndexedSeq[DataDesc],
+                protected val contexts: Array[Context] = Context.cpu(),
+                protected val epoch: Option[Int] = Some(0))
+  extends PredictBase {
+
+  private val logger = LoggerFactory.getLogger(classOf[Predictor])
+
+  require(inputDescriptors.nonEmpty, "inputDescriptors should not be empty")
+  require(inputDescriptors.head.layout.size != 0, "layout size should not be zero")
+
+  protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
+  protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex)
+    else 1
+
+  // descriptors the module is actually bound with; differs from inputDescriptors only when
+  // a batch axis has to be added below
+  protected[infer] var iDescriptors = inputDescriptors
+
+  inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex,
+    "batch size should be in the same index for all inputs"))
+
+  if (batchIndex != -1) {
+    inputDescriptors.foreach((f: DataDesc) => require(f.shape(batchIndex) == batchSize,
+      "batch size should be same for all inputs"))
+  } else {
+    // Note: this is assuming that the input needs a batch
+    logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize")
+    iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name,
+      Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout))
+    // 'N' was prepended, so the batch axis is now the first axis
+    batchIndex = 0
+  }
+
+  protected[infer] val mxNetHandler = MXNetHandler()
+
+  protected[infer] val mod = loadModule()
+
+  /**
+   * This method will take input as IndexedSeq one dimensional arrays and creates
+   * NDArray needed for inference. The array will be reshaped based on the input descriptors.
+   * The module is temporarily rebound with a batch size of 1 when needed and restored
+   * afterwards, even if prediction fails.
+   *
+   * @param input : An IndexedSequence of Scala one-dimensional arrays; an IndexedSequence
+   *              is needed when the model has more than one input
+   * @return IndexedSequence array of outputs.
+   */
+  override def predict(input: IndexedSeq[Array[Float]])
+    : IndexedSeq[Array[Float]] = {
+
+    require(input.length == inputDescriptors.length, ("number of inputs provided: %d" +
+      " does not match number of inputs in inputDescriptors: %d").format(input.length,
+      inputDescriptors.length))
+
+    for((i, d) <- input.zip(inputDescriptors)) {
+      require (i.length == d.shape.product/batchSize, "number of elements:" +
+        " %d in the input does not match the shape:%s".format( i.length, d.shape.toString()))
+    }
+
+    mxNetHandler.execute {
+      // one NDArray per input, shaped like its (batch-augmented) descriptor with batch size 1
+      val inputND = input.zip(iDescriptors).map { case (i, d) =>
+        NDArray.array(i, Shape(d.shape.toVector.patch(batchIndex, Vector(1), 1)))
+      }
+      try {
+        // rebind with batchsize 1
+        if (batchSize != 1) {
+          mod.bind(descriptorsWithBatchSize(1), forceRebind = true, forTraining = false)
+        }
+        try {
+          val resultND = mod.predict(new NDArrayIter(inputND, dataBatchSize = 1))
+          try resultND.map((f: NDArray) => f.toArray) finally resultND.foreach(_.dispose())
+        } finally {
+          // rebind to batchSize
+          if (batchSize != 1) {
+            mod.bind(iDescriptors, forTraining = false, forceRebind = true)
+          }
+        }
+      } finally {
+        inputND.foreach(_.dispose())
+      }
+    }
+  }
+
+  /**
+   * Predict using NDArray as input. This method is useful when the input is a batch of data
+   * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
+   * The module is temporarily rebound to the batch size of inputBatch when it differs
+   * from the configured one, and restored afterwards, even if prediction fails.
+   *
+   * @param inputBatch : IndexedSequence NDArrays.
+   * @return output of Predictions as NDArrays.
+   */
+  override def predictWithNDArray(inputBatch: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
+
+    require(inputBatch.length == inputDescriptors.length, ("number of inputs provided: %d" +
+      " do not match number of inputs in inputDescriptors: %d").format(inputBatch.length,
+      inputDescriptors.length))
+
+    // Shape validation, remove this when backend throws better error messages.
+    for((i, d) <- inputBatch.zip(iDescriptors)) {
+      require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
+        "All inputs should be of same batch size")
+      require(withoutBatchAxis(i.shape) == withoutBatchAxis(d.shape),
+        "Input Data Shape: %s should match the inputDescriptor shape: %s except batchSize".format(
+          i.shape.toString, d.shape.toString))
+    }
+
+    val inputBatchSize = inputBatch(0).shape(batchIndex)
+
+    mxNetHandler.execute {
+      // rebind with the new batchSize
+      if (batchSize != inputBatchSize) {
+        mod.bind(descriptorsWithBatchSize(inputBatchSize), forceRebind = true,
+          forTraining = false)
+      }
+      try {
+        mod.predict(new NDArrayIter(inputBatch, dataBatchSize = inputBatchSize))
+      } finally {
+        if (batchSize != inputBatchSize) {
+          mod.bind(iDescriptors, forceRebind = true, forTraining = false)
+        }
+      }
+    }
+  }
+
+  /** iDescriptors with the size of the batch axis replaced by `newBatchSize`. */
+  private def descriptorsWithBatchSize(newBatchSize: Int): IndexedSeq[DataDesc] =
+    iDescriptors.map((f: DataDesc) => new DataDesc(f.name,
+      Shape(f.shape.toVector.patch(batchIndex, Vector(newBatchSize), 1)), f.dtype, f.layout))
+
+  /** Dimensions of `shape` with the batch axis removed. */
+  private def withoutBatchAxis(shape: Shape): Vector[Int] =
+    shape.toVector.patch(batchIndex, Seq.empty[Int], 1)
+
+  private[infer] def loadModule(): Module = {
+    // an absent epoch falls back to the documented default of 0
+    val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, epoch.getOrElse(0),
+      contexts = contexts))
+    // bind with iDescriptors so a batch axis added in the constructor is included
+    mxNetHandler.execute(mod.bind(iDescriptors, forTraining = false))
+    mod
+  }
+}
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
new file mode 100644
index 0000000..4e99d56
--- /dev/null
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/package.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet
+
+/**
+ * Package-level settings of the inference API.
+ */
+package object infer {
+
+  /**
+   * Handler strategy consulted by MXNetHandler(): SingleThreadHandler shares one handler
+   * across all callers, OneThreadPerModelHandler creates a new single-thread handler per call.
+   */
+  private[mxnet] val handlerType =
+    MXNetHandlerType.SingleThreadHandler
+}
diff --git a/scala-package/infer/src/test/resources/log4j.properties b/scala-package/infer/src/test/resources/log4j.properties
new file mode 100644
index 0000000..d82fd7e
--- /dev/null
+++ b/scala-package/infer/src/test/resources/log4j.properties
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# for development debugging
+log4j.rootLogger = debug, stdout
+
+log4j.appender.stdout = org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target = System.out
+log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala
new file mode 100644
index 0000000..1a2f423
--- /dev/null
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ClassifierSuite.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import java.util
+
+import ml.dmlc.mxnet.module.Module
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.io
+
+class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
+
+  private val logger = LoggerFactory.getLogger(classOf[ClassifierSuite])
+
+  /** Per-run temporary directory that holds synset.txt; set by createTempModelFiles(). */
+  private var tempDir: Option[File] = None
+
+  /** Model path prefix handed to the classifiers; its parent directory contains synset.txt. */
+  var modelPath = ""
+
+  /** Canonical path of the synset.txt written for these tests. */
+  var synFilePath = ""
+
+  /**
+   * Creates a fresh temporary directory, writes a four-line synset.txt into it and points
+   * `modelPath` at a "model" prefix inside that same directory.
+   *
+   * A dedicated directory (instead of the shared java.io.tmpdir root) keeps concurrent
+   * runs from overwriting each other's synset.txt and lets afterAll() clean up completely.
+   */
+  def createTempModelFiles(): Unit = {
+    val dir = Files.createTempDirectory("ClassifierSuite").toFile
+    tempDir = Some(dir)
+    val tempDirPath = dir.getPath
+    logger.info("tempDirPath: %s".format(tempDirPath))
+
+    val modelDirPath = tempDirPath + File.separator + "model"
+    val synPath = tempDirPath + File.separator + "synset.txt"
+    val lines: util.List[String] = util.Arrays.
+      asList("class1 label1", "class2 label2", "class3 label3", "class4 label4")
+    // Files.write creates the file (or truncates an existing one) before writing the lines.
+    Files.write(Paths.get(synPath), lines)
+
+    this.modelPath = modelDirPath
+    this.synFilePath = new File(synPath).getCanonicalPath
+    logger.info("modelPath: %s".format(this.modelPath))
+    logger.info("synFilePath: %s".format(this.synFilePath))
+  }
+
+  override def beforeAll(): Unit = {
+    createTempModelFiles()
+  }
+
+  override def afterAll(): Unit = {
+    // Delete the file before its directory: File.delete() only removes empty directories.
+    if (synFilePath.nonEmpty) new File(synFilePath).delete()
+    tempDir.foreach(_.delete())
+  }
+
+  /**
+   * Predictor whose module is a Mockito mock, so no model files are loaded.
+   * Serves as the (mocked) predictor type returned by MyClassifier.
+   */
+  class MyClassyPredictor(val modelPathPrefix: String,
+                          override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  /** Classifier whose predictor is a Mockito mock; getSynset() exposes the loaded synset. */
+  class MyClassifier(modelPathPrefix: String,
+                     protected override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Classifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
+
+    override def getPredictor(): MyClassyPredictor = {
+      Mockito.mock(classOf[MyClassyPredictor])
+    }
+
+    def getSynset(): IndexedSeq[String] = synset
+  }
+
+  test("ClassifierSuite-getSynsetFilePath") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    assertResult(this.synFilePath) {
+      testClassifier.synsetFilePath
+    }
+  }
+
+  test("ClassifierSuite-readSynsetFile") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    // Close the Source explicitly; otherwise the file handle stays open.
+    val source = io.Source.fromFile(this.synFilePath)
+    val expected = try source.getLines().toList finally source.close()
+
+    assertResult(expected) {
+      testClassifier.getSynset()
+    }
+  }
+
+  test("ClassifierSuite-flatArray-topK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+      classify(IndexedSeq(inputData), topK = Some(10))
+
+    // With topK the scores are expected in descending order.
+    assertResult(predictResult(0).sortBy(-_)) {
+      result.map(_._2).toArray
+    }
+  }
+
+  test("ClassifierSuite-flatArrayInput") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Float]]]))
+
+    val result: IndexedSeq[(String, Float)] = testClassifier.
+      classify(IndexedSeq(inputData))
+
+    // Without topK the scores are expected in their original (synset) order.
+    assertResult(predictResult(0)) {
+      result.map(_._2).toArray
+    }
+  }
+
+  test("ClassifierSuite-NDArray1InputWithoutTopK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(1, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(1, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+      classifyWithNDArray(IndexedSeq(inputData))
+
+    assert(predictResult.size == result.size)
+
+    for (i <- predictResult.indices) {
+      assertResult(predictResult(i)) {
+        result(i).map(_._2).toArray
+      }
+    }
+
+    // The mocked predictor never touches the input, so this test owns and frees it.
+    inputData.dispose()
+  }
+
+  test("ClassifierSuite-NDArray3InputWithTopK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDataShape = Shape(3, 3, 2, 2)
+    val inputData = NDArray.ones(inputDataShape)
+
+    val predictResult: IndexedSeq[Array[Float]] =
+      IndexedSeq[Array[Float]](Array(.98f, 0.97f, 0.96f, 0.99f),
+        Array(.98f, 0.97f, 0.96f, 0.99f), Array(.98f, 0.97f, 0.96f, 0.99f))
+
+    val predictResultND: NDArray = NDArray.array(predictResult.flatten.toArray, Shape(3, 4))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(IndexedSeq(predictResultND)).when(testClassifier.predictor)
+      .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
+
+    val result: IndexedSeq[IndexedSeq[(String, Float)]] = testClassifier.
+      classifyWithNDArray(IndexedSeq(inputData), topK = Some(10))
+
+    assert(predictResult.size == result.size)
+
+    for (i <- predictResult.indices) {
+      assertResult(predictResult(i).sortBy(-_)) {
+        result(i).map(_._2).toArray
+      }
+    }
+
+    // The mocked predictor never touches the input, so this test owns and frees it.
+    inputData.dispose()
+  }
+
+}
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
new file mode 100644
index 0000000..da4d965
--- /dev/null
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/PredictorSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.infer
+
+
+import ml.dmlc.mxnet.io.NDArrayIter
+import ml.dmlc.mxnet.module.{BaseModule, Module}
+import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class PredictorSuite extends FunSuite with BeforeAndAfterAll {
+
+  /**
+   * Predictor backed by a Mockito-mocked Module, so no model is loaded from disk.
+   * Exposes the protected batch bookkeeping (iDescriptors, batchSize, batchIndex)
+   * for assertions.
+   */
+  class MyPredictor(val modelPathPrefix: String,
+                    override val inputDescriptors: IndexedSeq[DataDesc])
+    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {
+
+    override def loadModule(): Module = mockModule
+
+    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
+    val getBatchSize: Int = batchSize
+    val getBatchIndex: Int = batchIndex
+
+    // NOTE(review): presumably lazy because loadModule() may be invoked while Predictor's
+    // own constructor runs, before this subclass's strict vals exist — confirm in Predictor.
+    lazy val mockModule: Module = Mockito.mock(classOf[Module])
+  }
+
+  test("PredictorSuite-testPredictorConstruction") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)))
+
+    val mockPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    assert(mockPredictor.getBatchSize == 1)
+    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))
+
+    // Inputs that disagree on batch size (1 vs 2) must be rejected.
+    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)),
+      new DataDesc("data", Shape(2, 3, 2, 2)))
+
+    assertThrows[IllegalArgumentException] {
+      new MyPredictor("xyz", inputDescriptor2)
+    }
+
+    // A layout without an 'N' axis carries no batch size, so the batch size defaults to 1.
+    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
+    val p2 = new MyPredictor("xyz", iDesc2)
+    assert(p2.getBatchSize == 1, "should use a default batch size of 1")
+  }
+
+  test("PredictorSuite-testWithFlatArrays") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Float](12)(1)
+
+    // This will be disposed at the end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1 ")
+
+    assert(testFun(0).sameElements(Array.fill[Float](12)(1)))
+
+    // Expect exactly two binds on the mock: one for the input's batch size of 1 and one
+    // rebinding to the original descriptor. The real loadModule's initial bind is not
+    // counted because loadModule is overridden to return the mock.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean],
+      any[Option[BaseModule]], any[String])
+  }
+
+  test("PredictorSuite-testWithNDArray") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = NDArray.ones(Shape(1, 3, 2, 2))
+
+    // Returned unchanged by the mocked module.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1")
+
+    assert(testFun(0).toArray.sameElements(Array.fill[Float](12)(1)))
+
+    // Same two-bind expectation as in testWithFlatArrays.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean],
+      any[Option[BaseModule]], any[String])
+
+    // Free the native memory of the input this test allocated.
+    inputData.dispose()
+  }
+}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 02bcd86..27dfe2f 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -37,6 +37,7 @@
<module>macros</module>
<module>core</module>
<module>native</module>
+ <module>infer</module>
<module>examples</module>
<module>spark</module>
<module>assembly</module>
--
To stop receiving notification emails like this one, please contact
nswamy@apache.org.