You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/03/26 23:27:22 UTC

[incubator-mxnet] branch master updated: adding context parameter to infer api- imageclassifier and objectdetector (#10252)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 777089d  adding context parameter to infer api- imageclassifier and objectdetector (#10252)
777089d is described below

commit 777089dd771735aa2c8efb4ae088a4a68ce896a4
Author: Roshani Nagmote <ro...@gmail.com>
AuthorDate: Mon Mar 26 16:27:17 2018 -0700

    adding context parameter to infer api- imageclassifier and objectdetector (#10252)
    
    * adding context parameter
    
    * parameter description added
---
 .../ml/dmlc/mxnet/infer/ImageClassifier.scala      | 18 +++++++----
 .../scala/ml/dmlc/mxnet/infer/ObjectDetector.scala | 37 ++++++++++++++--------
 .../ml/dmlc/mxnet/infer/ImageClassifierSuite.scala | 26 ++++++++-------
 .../ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala  | 11 ++++---
 4 files changed, 56 insertions(+), 36 deletions(-)

diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
index 45c4e76..070b0bf 100644
--- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ImageClassifier.scala
@@ -17,7 +17,7 @@
 
 package ml.dmlc.mxnet.infer
 
-import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}
 
 import scala.collection.mutable.ListBuffer
 
@@ -37,13 +37,15 @@ import javax.imageio.ImageIO
   *                         file://model-dir/synset.txt
   * @param inputDescriptors Descriptors defining the input node names, shape,
   *                         layout and Type parameters
+  * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+  * @param epoch Model epoch to load, defaults to 0.
   */
 class ImageClassifier(modelPathPrefix: String,
-                      inputDescriptors: IndexedSeq[DataDesc])
+                      inputDescriptors: IndexedSeq[DataDesc],
+                      contexts: Array[Context] = Context.cpu(),
+                      epoch: Option[Int] = Some(0))
                       extends Classifier(modelPathPrefix,
-                      inputDescriptors) {
-
-  val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors)
+                      inputDescriptors, contexts, epoch) {
 
   protected[infer] val inputLayout = inputDescriptors.head.layout
 
@@ -108,8 +110,10 @@ class ImageClassifier(modelPathPrefix: String,
     result
   }
 
-  def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = {
-    new Classifier(modelPathPrefix, inputDescriptors)
+  def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+                    contexts: Array[Context] = Context.cpu(),
+                    epoch: Option[Int] = Some(0)): Classifier = {
+    new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch)
   }
 }
 
diff --git a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
index 2d83caf..30e1432 100644
--- a/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/ml/dmlc/mxnet/infer/ObjectDetector.scala
@@ -16,12 +16,14 @@
  */
 
 package ml.dmlc.mxnet.infer
+
 // scalastyle:off
 import java.awt.image.BufferedImage
 // scalastyle:on
-import ml.dmlc.mxnet.NDArray
-import ml.dmlc.mxnet.DataDesc
+
+import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
 import scala.collection.mutable.ListBuffer
+
 /**
   * A class for object detection tasks
   *
@@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer
   *                         file://model-dir/synset.txt
   * @param inputDescriptors Descriptors defining the input node names, shape,
   *                         layout and Type parameters
+  * @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
+  * @param epoch Model epoch to load, defaults to 0.
   */
 class ObjectDetector(modelPathPrefix: String,
-                     inputDescriptors: IndexedSeq[DataDesc]) {
+                     inputDescriptors: IndexedSeq[DataDesc],
+                     contexts: Array[Context] = Context.cpu(),
+                     epoch: Option[Int] = Some(0)) {
 
-  val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
+  val imgClassifier: ImageClassifier =
+    getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
 
   val inputShape = imgClassifier.inputShape
 
@@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String,
     * To Detect bounding boxes and corresponding labels
     *
     * @param inputImage : Input image as a BufferedImage
-    * @param topK : Get top k elements with maximum probability
+    * @param topK       : Get top k elements with maximum probability
     * @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
     */
   def imageObjectDetect(inputImage: BufferedImage,
@@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String,
   /**
     * Takes input images as NDArrays. Useful when you want to perform multiple operations on
     * the input Array, or when you want to pass a batch of input images.
+    *
     * @param input : Indexed Sequence of NDArrays
-    * @param topK : (Optional) How many top_k(sorting will be based on the last axis)
-    *             elements to return. If not passed, returns all unsorted output.
+    * @param topK  : (Optional) How many top_k(sorting will be based on the last axis)
+    *              elements to return. If not passed, returns all unsorted output.
     * @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
     */
   def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
@@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String,
     batchResult.toIndexedSeq
   }
 
-  private def sortAndReformat(predictResultND : NDArray, topK: Option[Int])
+  private def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
   : IndexedSeq[(String, Array[Float])] = {
     val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
-    val accuracy : ListBuffer[Float] = ListBuffer[Float]()
+    val accuracy: ListBuffer[Float] = ListBuffer[Float]()
 
     // iterating over all the predictions
     val length = predictResultND.shape(0)
@@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String,
       handler.execute(r.dispose())
     }
     var result = IndexedSeq[(String, Array[Float])]()
-    if(topK.isDefined) {
+    if (topK.isDefined) {
       var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
       sortedIndices = sortedIndices.take(topK.get)
       // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
@@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String,
 
   /**
     * To classify batch of input images according to the provided model
+    *
     * @param inputBatch Input batch of Buffered images
-    * @param topK Get top k elements with maximum probability
+    * @param topK       Get top k elements with maximum probability
     * @return List of list of tuples of (class, probability)
     */
   def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
@@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String,
     result
   }
 
-  def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
+  def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+                         contexts: Array[Context] = Context.cpu(),
+                         epoch: Option[Int] = Some(0)):
   ImageClassifier = {
-    new ImageClassifier(modelPathPrefix, inputDescriptors)
+    new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
   }
 
 }
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
index 96fc800..85059be 100644
--- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ImageClassifierSuite.scala
@@ -17,11 +17,10 @@
 
 package ml.dmlc.mxnet.infer
 
-import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray}
-
+import ml.dmlc.mxnet._
 import org.mockito.Matchers._
 import org.mockito.Mockito
-import org.scalatest.{BeforeAndAfterAll}
+import org.scalatest.BeforeAndAfterAll
 
 // scalastyle:off
 import java.awt.image.BufferedImage
@@ -33,7 +32,7 @@ import java.awt.image.BufferedImage
 class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
 
   class MyImageClassifier(modelPathPrefix: String,
-                           inputDescriptors: IndexedSeq[DataDesc])
+                          inputDescriptors: IndexedSeq[DataDesc])
     extends ImageClassifier(modelPathPrefix, inputDescriptors) {
 
     override def getPredictor(): MyClassyPredictor = {
@@ -41,7 +40,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
     }
 
     override def getClassifier(modelPathPrefix: String, inputDescriptors:
-    IndexedSeq[DataDesc]): Classifier = {
+    IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
+                               epoch: Option[Int] = Some(0)): Classifier = {
       Mockito.mock(classOf[Classifier])
     }
 
@@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
 
     val synset = testImageClassifier.synset
 
-    val predictExpectedOp : List[(String, Float)] =
+    val predictExpectedOp: List[(String, Float)] =
       List[(String, Float)]((synset(1), .98f), (synset(2), .97f),
         (synset(3), .96f), (synset(0), .99f))
 
@@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
     Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
       .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
 
-    Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
+    Mockito.doReturn(IndexedSeq(predictExpectedOp))
+      .when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
       .classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
 
     val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] =
       testImageClassifier.classifyImage(inputImage, Some(4))
 
-    for(i <- predictExpected.indices) {
+    for (i <- predictExpected.indices) {
       assertResult(predictExpected(i).sortBy(-_)) {
         predictResult(i).map(_._2).toArray
       }
@@ -119,15 +120,15 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
 
     val predictExpected: IndexedSeq[Array[Array[Float]]] =
       IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f),
-            Array(.98f, 0.97f, 0.96f, 0.99f)))
+        Array(.98f, 0.97f, 0.96f, 0.99f)))
 
     val synset = testImageClassifier.synset
 
-    val predictExpectedOp : List[List[(String, Float)]] =
+    val predictExpectedOp: List[List[(String, Float)]] =
       List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f),
         (synset(3), .96f), (synset(0), .99f)),
         List((synset(1), .98f), (synset(2), .97f),
-        (synset(3), .96f), (synset(0), .99f)))
+          (synset(3), .96f), (synset(0), .99f)))
 
     val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray,
       Shape(2, 4))
@@ -135,7 +136,8 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
     Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
       .predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))
 
-    Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
+    Mockito.doReturn(IndexedSeq(predictExpectedOp))
+      .when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
       .classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))
 
     val result: IndexedSeq[IndexedSeq[(String, Float)]] =
diff --git a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
index a691aa3..5e6f32f 100644
--- a/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
+++ b/scala-package/infer/src/test/scala/ml/dmlc/mxnet/infer/ObjectDetectorSuite.scala
@@ -23,7 +23,7 @@ import java.awt.image.BufferedImage
 // scalastyle:on
 import ml.dmlc.mxnet.Context
 import ml.dmlc.mxnet.DataDesc
-import ml.dmlc.mxnet.{NDArray, Shape}
+import ml.dmlc.mxnet.{Context, NDArray, Shape}
 import org.mockito.Matchers.any
 import org.mockito.Mockito
 import org.scalatest.BeforeAndAfterAll
@@ -36,7 +36,8 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
     extends ObjectDetector(modelPathPrefix, inputDescriptors) {
 
     override def getImageClassifier(modelPathPrefix: String, inputDescriptors:
-    IndexedSeq[DataDesc]): ImageClassifier = {
+        IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
+        epoch: Option[Int] = Some(0)): ImageClassifier = {
       new MyImageClassifier(modelPathPrefix, inputDescriptors)
     }
 
@@ -44,13 +45,15 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
 
   class MyImageClassifier(modelPathPrefix: String,
                      protected override val inputDescriptors: IndexedSeq[DataDesc])
-    extends ImageClassifier(modelPathPrefix, inputDescriptors) {
+    extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {
 
     override def getPredictor(): MyClassyPredictor = {
       Mockito.mock(classOf[MyClassyPredictor])
     }
 
-    override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
+    override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
+                               contexts: Array[Context] = Context.cpu(),
+                               epoch: Option[Int] = Some(0)):
     Classifier = {
       new MyClassifier(modelPathPrefix, inputDescriptors)
     }

-- 
To stop receiving notification emails like this one, please contact
nswamy@apache.org.