You are viewing a plain-text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/03/26 23:27:22 UTC
[incubator-mxnet] branch master updated: adding context parameter to infer api- imageclassifier and objectdetector (#10252)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 777089d adding context parameter to infer api- imageclassifier and objectdetector (#10252)
777089d is described below
commit 777089dd771735aa2c8efb4ae088a4a68ce896a4
Author: Roshani Nagmote <ro...@gmail.com>
AuthorDate: Mon Mar 26 16:27:17 2018 -0700
adding context parameter to infer api- imageclassifier and objectdetector (#10252)
* adding context parameter
* parameter description added
---
.../ml/dmlc/mxnet/infer/ImageClassifier.scala | 18 +++++++----
.../scala/ml/dmlc/mxnet/infer/ObjectDetector.scala | 37 ++++++++++++++--------
.../ml/dmlc/mxnet/infer/ImageClassifierSuite.scala | 26 ++++++++-------
.../ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala | 11 ++++---
4 files changed, 56 insertions(+), 36 deletions(-)
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
index 45c4e76..070b0bf 100644
--- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
@@ -17,7 +17,7 @@
package ml.dmlc.mxnet.infer
-import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
import scala.collection.mutable.ListBuffer
@@ -37,13 +37,15 @@ import javax.imageio.ImageIO
* file://model-dir/synset.txt
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and Type parameters
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+ * @param epoch Model epoch to load, defaults to 0.
*/
class ImageClassifier(modelPathPrefix: String,
- inputDescriptors: IndexedSeq[DataDesc])
+ inputDescriptors: IndexedSeq[DataDesc],
+ contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0))
extends Classifier(modelPathPrefix,
- inputDescriptors) {
-
- val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors)
+ inputDescriptors, contexts, epoch) {
protected[infer] val inputLayout = inputDescriptors.head.layout
@@ -108,8 +110,10 @@ class ImageClassifier(modelPathPrefix: String,
result
}
- def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = {
- new Classifier(modelPathPrefix, inputDescriptors)
+ def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+ contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)): Classifier = {
+ new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
}
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
index 2d83caf..30e1432 100644
--- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
@@ -16,12 +16,14 @@
*/
package ml.dmlc.mxnet.infer
+
// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
-import ml.dmlc.mxnet.NDArray
-import ml.dmlc.mxnet.DataDesc
+
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
import scala.collection.mutable.ListBuffer
+
/**
* A class for object detection tasks
*
@@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer
* file://model-dir/synset.txt
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and Type parameters
+ * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+ * @param epoch Model epoch to load, defaults to 0.
*/
class ObjectDetector(modelPathPrefix: String,
- inputDescriptors: IndexedSeq[DataDesc]) {
+ inputDescriptors: IndexedSeq[DataDesc],
+ contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)) {
- val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
+ val imgClassifier: ImageClassifier =
+ getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
val inputShape = imgClassifier.inputShape
@@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String,
* To Detect bounding boxes and corresponding labels
*
* @param inputImage : PathPrefix of the input image
- * @param topK : Get top k elements with maximum probability
+ * @param topK : Get top k elements with maximum probability
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def imageObjectDetect(inputImage: BufferedImage,
@@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String,
/**
* Takes input images as NDArrays. Useful when you want to perform multiple operations on
* the input Array, or when you want to pass a batch of input images.
+ *
* @param input : Indexed Sequence of NDArrays
- * @param topK : (Optional) How many top_k(sorting will be based on the last axis)
- * elements to return. If not passed, returns all unsorted output.
+ * @param topK : (Optional) How many top_k(sorting will be based on the last axis)
+ * elements to return. If not passed, returns all unsorted output.
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
@@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String,
batchResult.toIndexedSeq
}
- private def sortAndReformat(predictResultND : NDArray, topK: Option[Int])
+ private def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
: IndexedSeq[(String, Array[Float])] = {
val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
- val accuracy : ListBuffer[Float] = ListBuffer[Float]()
+ val accuracy: ListBuffer[Float] = ListBuffer[Float]()
// iterating over the all the predictions
val length = predictResultND.shape(0)
@@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String,
handler.execute(r.dispose())
}
var result = IndexedSeq[(String, Array[Float])]()
- if(topK.isDefined) {
+ if (topK.isDefined) {
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
@@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String,
/**
* To classify batch of input images according to the provided model
+ *
* @param inputBatch Input batch of Buffered images
- * @param topK Get top k elements with maximum probability
+ * @param topK Get top k elements with maximum probability
* @return List of list of tuples of (class, probability)
*/
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
@@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String,
result
}
- def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
+ def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+ contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)):
ImageClassifier = {
- new ImageClassifier(modelPathPrefix, inputDescriptors)
+ new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
}
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
index 96fc800..85059be 100644
--- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
@@ -17,11 +17,10 @@
package ml.dmlc.mxnet.infer
-import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray}
-
+import ml.dmlc.mxnet._
import org.mockito.Matchers._
import org.mockito.Mockito
-import org.scalatest.{BeforeAndAfterAll}
+import org.scalatest.BeforeAndAfterAll
// scalastyle:off
import java.awt.image.BufferedImage
@@ -33,7 +32,7 @@ import java.awt.image.BufferedImage
class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
class MyImageClassifier(modelPathPrefix: String,
- inputDescriptors: IndexedSeq[DataDesc])
+ inputDescriptors: IndexedSeq[DataDesc])
extends ImageClassifier(modelPathPrefix, inputDescriptors) {
override def getPredictor(): MyClassyPredictor = {
@@ -41,7 +40,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
}
override def getClassifier(modelPathPrefix: String, inputDescriptors:
- IndexedSeq[DataDesc]): Classifier = {
+ IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)): Classifier = {
Mockito.mock(classOf[Classifier])
}
@@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
val synset = testImageClassifier.synset
- val predictExpectedOp : List[(String, Float)] =
+ val predictExpectedOp: List[(String, Float)] =
List[(String, Float)]((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f))
@@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
- Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
+ Mockito.doReturn(IndexedSeq(predictExpectedOp))
+ .when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] =
testImageClassifier.classifyImage(inputImage, Some(4))
- for(i <- predictExpected.indices) {
+ for (i <- predictExpected.indices) {
assertResult(predictExpected(i).sortBy(-_)) {
predictResult(i).map(_._2).toArray
}
@@ -119,15 +120,15 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
val predictExpected: IndexedSeq[Array[Array[Float]]] =
IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f),
- Array(.98f, 0.97f, 0.96f, 0.99f)))
+ Array(.98f, 0.97f, 0.96f, 0.99f)))
val synset = testImageClassifier.synset
- val predictExpectedOp : List[List[(String, Float)]] =
+ val predictExpectedOp: List[List[(String, Float)]] =
List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f)),
List((synset(1), .98f), (synset(2), .97f),
- (synset(3), .96f), (synset(0), .99f)))
+ (synset(3), .96f), (synset(0), .99f)))
val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray,
Shape(2, 4))
@@ -135,7 +136,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
- Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
+ Mockito.doReturn(IndexedSeq(predictExpectedOp))
+ .when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
val result: IndexedSeq[IndexedSeq[(String, Float)]] =
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
index a691aa3..5e6f32f 100644
--- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
@@ -23,7 +23,7 @@ import java.awt.image.BufferedImage
// scalastyle:on
import ml.dmlc.mxnet.Context
import ml.dmlc.mxnet.DataDesc
-import ml.dmlc.mxnet.{NDArray, Shape}
+import ml.dmlc.mxnet.{Context, NDArray, Shape}
import org.mockito.Matchers.any
import org.mockito.Mockito
import org.scalatest.BeforeAndAfterAll
@@ -36,7 +36,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
extends ObjectDetector(modelPathPrefix, inputDescriptors) {
override def getImageClassifier(modelPathPrefix: String, inputDescriptors:
- IndexedSeq[DataDesc]): ImageClassifier = {
+ IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)): ImageClassifier = {
new MyImageClassifier(modelPathPrefix, inputDescriptors)
}
@@ -44,13 +45,15 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
class MyImageClassifier(modelPathPrefix: String,
protected override val inputDescriptors: IndexedSeq[DataDesc])
- extends ImageClassifier(modelPathPrefix, inputDescriptors) {
+ extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
override def getPredictor(): MyClassyPredictor = {
Mockito.mock(classOf[MyClassyPredictor])
}
- override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
+ override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+ contexts: Array[Context] = Context.cpu(),
+ epoch: Option[Int] = Some(0)):
Classifier = {
new MyClassifier(modelPathPrefix, inputDescriptors)
}
--
To stop receiving notification emails like this one, please contact
nswamy@apache.org.