You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this text.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/24 19:37:11 UTC
[incubator-mxnet] branch master updated: [MXNET-1293] Adding
Iterables instead of List to method signature for infer APIs in Java
(#13977)
This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 24412df [MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977)
24412df is described below
commit 24412df84934b57f6e5ed7ac135bc6fe5402cff2
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jan 24 11:36:55 2019 -0800
[MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java (#13977)
* Added Iterables as input type instead of List in Predictor for Java
* Added Iterables to ObjectDetector API
* Added tests for Predictor API
* Added tests for ObjectDetector
---
.../mxnet/infer/javaapi/ObjectDetector.scala | 10 ++++----
.../org/apache/mxnet/infer/javaapi/Predictor.scala | 12 ++++-----
.../mxnet/infer/javaapi/ObjectDetectorTest.java | 25 +++++++++++++++++++
.../apache/mxnet/infer/javaapi/PredictorTest.java | 29 +++++++++++++++++++---
4 files changed, 62 insertions(+), 14 deletions(-)
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
index 3014f8d..05334e4 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -44,8 +44,8 @@ import scala.language.implicitConversions
*/
class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){
- def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
- java.util.List[Context], epoch: Int)
+ def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
+ java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
@@ -79,7 +79,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @return List of list of tuples of
* (class, [probability, xmin, ymin, xmax, ymax])
*/
- def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+ def objectDetectWithNDArray(input: java.lang.Iterable[NDArray], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
@@ -92,7 +92,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @param topK Number of result elements to return, sorted by probability
* @return List of list of tuples of (class, probability)
*/
- def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+ def imageBatchObjectDetect(inputBatch: java.lang.Iterable[BufferedImage], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
@@ -122,7 +122,7 @@ object ObjectDetector {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}
- def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+ def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 146fe93..6c0871f 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -40,8 +40,8 @@ import scala.collection.JavaConverters._
// JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){
- def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
- contexts: java.util.List[Context], epoch: Int)
+ def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc],
+ contexts: java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
@@ -97,10 +97,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
}
/**
- * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
+ * Takes input as a List of one-dimensional iterables and creates the NDArray needed for inference.
* The array will be reshaped based on the input descriptors.
*
- * @param input: A List of a one-dimensional array.
+ * @param input: A List of one-dimensional iterables of DType Float.
An extra List is needed for when the model has more than one input.
* @return Indexed sequence array of outputs
*/
@@ -118,10 +118,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
* This method is useful when the input is a batch of data
* Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
*
- * @param input List of NDArrays
+ * @param input Iterable of NDArrays
* @return Output of predictions as NDArrays
*/
- def predictWithNDArray(input: java.util.List[NDArray]):
+ def predictWithNDArray(input: java.lang.Iterable[NDArray]):
java.util.List[NDArray] = {
val ret = predictor.predictWithNDArray(convert(JavaConverters
.asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
index a5e6491..3219b5a 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
@@ -29,7 +29,9 @@ import org.mockito.Mockito;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
public class ObjectDetectorTest {
@@ -93,6 +95,17 @@ public class ObjectDetectorTest {
}
@Test
+ public void testObjectDetectorWithIterableOfBatchImage() {
+
+ Set<BufferedImage> batchImage = new HashSet<>();
+ batchImage.add(inputImage);
+ Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
+ List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
+ Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
+ Assert.assertEquals(expectedResult, actualResult);
+ }
+
+ @Test
public void testObjectDetectorWithNDArrayInput() {
NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
@@ -103,4 +116,16 @@ public class ObjectDetectorTest {
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}
+
+ @Test
+ public void testObjectDetectorWithIterableOfNDArrayInput() {
+ // Passes a Set (a non-List Iterable); stub, call and verify all use the same topK fixture.
+ NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
+ Set<NDArray> inputL = new HashSet<>();
+ inputL.add(inputArr);
+ Mockito.when(objectDetector.objectDetectWithNDArray(inputL, topK)).thenReturn(expectedResult);
+ List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
+ Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
+ Assert.assertEquals(expectedResult, actualResult);
+ }
}
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
index e7a6c96..0d83c74 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
@@ -25,9 +25,7 @@ import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+import java.util.*;
public class PredictorTest {
@@ -81,6 +79,31 @@ public class PredictorTest {
}
@Test
+ public void testPredictWithIterablesNDArray() {
+ // Passes a Set (a non-List Iterable) to predictWithNDArray and checks the stubbed result comes back.
+ float[] tmpArr = new float[224];
+ for (int y = 0; y < 224; y++)
+ tmpArr[y] = (int) (Math.random() * 10);
+
+ NDArray arr = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+
+ Set<NDArray> inputSet = new HashSet<>();
+ inputSet.add(arr);
+
+ NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+ List<NDArray> expectedResult = new ArrayList<>();
+ expectedResult.add(expected);
+
+ Mockito.when(mockPredictor.predictWithNDArray(inputSet)).thenReturn(expectedResult);
+
+ List<NDArray> actualOutput = mockPredictor.predictWithNDArray(inputSet);
+
+ Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputSet);
+
+ Assert.assertEquals(expectedResult, actualOutput);
+ }
+
+ @Test
public void testPredictWithListOfFloatsAsInput() {
List<List<Float>> input = new ArrayList<>();