You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/24 19:37:11 UTC

[incubator-mxnet] branch master updated: [MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 24412df  [MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977)
24412df is described below

commit 24412df84934b57f6e5ed7ac135bc6fe5402cff2
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jan 24 11:36:55 2019 -0800

    [MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977)
    
    * Added Iterables as input type instead of List in Predictor for Java
    
    * Added Iterables to ObjectDetector API
    
    * Added tests for Predictor API
    
    * Added tests for ObjectDetector
---
 .../mxnet/infer/javaapi/ObjectDetector.scala       | 10 ++++----
 .../org/apache/mxnet/infer/javaapi/Predictor.scala | 12 ++++-----
 .../mxnet/infer/javaapi/ObjectDetectorTest.java    | 25 +++++++++++++++++++
 .../apache/mxnet/infer/javaapi/PredictorTest.java  | 29 +++++++++++++++++++---
 4 files changed, 62 insertions(+), 14 deletions(-)

diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
index 3014f8d..05334e4 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -44,8 +44,8 @@ import scala.language.implicitConversions
   */
 class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){
 
-  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
-  java.util.List[Context], epoch: Int)
+  def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
+  java.lang.Iterable[Context], epoch: Int)
   = this {
     val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
       .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
@@ -79,7 +79,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
     * @return                 List of list of tuples of
     *                         (class, [probability, xmin, ymin, xmax, ymax])
     */
-  def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+  def objectDetectWithNDArray(input: java.lang.Iterable[NDArray], topK: Int):
   java.util.List[java.util.List[ObjectDetectorOutput]] = {
     val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
     (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
@@ -92,7 +92,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
     * @param topK             Number of result elements to return, sorted by probability
     * @return                 List of list of tuples of (class, probability)
     */
-  def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+  def imageBatchObjectDetect(inputBatch: java.lang.Iterable[BufferedImage], topK: Int):
       java.util.List[java.util.List[ObjectDetectorOutput]] = {
     val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
     (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
@@ -122,7 +122,7 @@ object ObjectDetector {
     org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
   }
 
-  def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+  def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
     org.apache.mxnet.infer.ImageClassifier
       .loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
   }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 146fe93..6c0871f 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -40,8 +40,8 @@ import scala.collection.JavaConverters._
 
 // JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
 class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){
-  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
-           contexts: java.util.List[Context], epoch: Int)
+  def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc],
+           contexts: java.lang.Iterable[Context], epoch: Int)
   = this {
     val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
       .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
@@ -97,10 +97,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
   }
 
   /**
-    * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
+    * Takes input as List of one dimensional iterables and creates the NDArray needed for inference
     * The array will be reshaped based on the input descriptors.
     *
-    * @param input:            A List of a one-dimensional array.
+    * @param input:            A List of a one-dimensional iterables of DType Float.
                               An extra List is needed for when the model has more than one input.
     * @return                  Indexed sequence array of outputs
     */
@@ -118,10 +118,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
     * This method is useful when the input is a batch of data
     * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
     *
-    * @param input             List of NDArrays
+    * @param input             Iterable of NDArrays
     * @return                  Output of predictions as NDArrays
     */
-  def predictWithNDArray(input: java.util.List[NDArray]):
+  def predictWithNDArray(input: java.lang.Iterable[NDArray]):
   java.util.List[NDArray] = {
     val ret = predictor.predictWithNDArray(convert(JavaConverters
       .asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
index a5e6491..3219b5a 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
@@ -29,7 +29,9 @@ import org.mockito.Mockito;
 
 import java.awt.image.BufferedImage;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 public class ObjectDetectorTest {
 
@@ -93,6 +95,17 @@ public class ObjectDetectorTest {
     }
 
     @Test
+    public void testObjectDetectorWithIterableOfBatchImage() {
+
+        Set<BufferedImage> batchImage = new HashSet<>();
+        batchImage.add(inputImage);
+        Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
+        List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
+        Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
+        Assert.assertEquals(expectedResult, actualResult);
+    }
+
+    @Test
     public void testObjectDetectorWithNDArrayInput() {
 
         NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
@@ -103,4 +116,16 @@ public class ObjectDetectorTest {
         Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
         Assert.assertEquals(expectedResult, actualResult);
     }
+
+    @Test
+    public void testObjectDetectorWithIterableOfNDArrayInput() {
+
+        NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
+        Set<NDArray> inputL = new HashSet<>();
+        inputL.add(inputArr);
+        Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
+        List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
+        Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
+        Assert.assertEquals(expectedResult, actualResult);
+    }
 }
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
index e7a6c96..0d83c74 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
@@ -25,9 +25,7 @@ import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+import java.util.*;
 
 public class PredictorTest {
 
@@ -81,6 +79,31 @@ public class PredictorTest {
     }
 
     @Test
+    public void testPredictWithIterablesNDArray() {
+
+        float[] tmpArr = new float[224];
+        for (int y = 0; y < 224; y++)
+            tmpArr[y] = (int) (Math.random() * 10);
+
+        NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+
+        Set<NDArray> inputSet = new HashSet<>();
+        inputSet.add(arr);
+
+        NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+        List<NDArray> expectedResult = new ArrayList<>();
+        expectedResult.add(expected);
+
+        Mockito.when(mockPredictor.predictWithNDArray(inputSet)).thenReturn(expectedResult);
+
+        List<NDArray> actualOutput = mockPredictor.predictWithNDArray(inputSet);
+
+        Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputSet);
+
+        Assert.assertEquals(expectedResult, actualOutput);
+    }
+
+    @Test
     public void testPredictWithListOfFloatsAsInput() {
         List<List<Float>> input = new ArrayList<>();