You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/19 22:47:16 UTC

[GitHub] nswamy closed pull request #12830: Java Inference api and SSD example

nswamy closed pull request #12830: Java Inference api and SSD example
URL: https://github.com/apache/incubator-mxnet/pull/12830
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
index 8cea892b580..adb8830de06 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
@@ -17,9 +17,25 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# Optional 4th argument "gpu" selects the GPU assembly jar. It cannot be $1,
+# which is the model path prefix (MODEL_DIR below).
+hw_type=cpu
+if [[ $4 = gpu ]]
+then
+    hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+# Match the literal prefix "darwin" (macOS); the bracket expression "[darwin]*"
+# would match any OSTYPE starting with one of the letters d, a, r, w, i or n.
+if [[ $OSTYPE = darwin* ]]
+then
+    platform=osx-x86_64
+fi
 
 MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
 
 # model dir and prefix
 MODEL_DIR=$1
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
new file mode 100755
index 00000000000..f444a3a59af
--- /dev/null
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Usage: run_ssd_java_example.sh <model-path-prefix> <input-image> <input-dir> [gpu]
+# The optional "gpu" argument only selects the GPU assembly jar; the example
+# itself picks the GPU context only when SCALA_TEST_ON_GPU=1 is set.
+hw_type=cpu
+if [[ $4 = gpu ]]
+then
+    hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+# Match the literal prefix "darwin" (macOS); the bracket expression "[darwin]*"
+# would match any OSTYPE starting with one of the letters d, a, r, w, i or n.
+if [[ $OSTYPE = darwin* ]]
+then
+    platform=osx-x86_64
+fi
+
+MXNET_ROOT=$(cd "$(dirname "$0")/../../../../../"; pwd)
+# The JVM expands each trailing "*" entry to the jars in that directory, so
+# $CLASS_PATH is quoted below to keep the shell from globbing it first.
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+
+# model dir and prefix
+MODEL_DIR=$1
+# input image
+INPUT_IMG=$2
+# which input image dir
+INPUT_DIR=$3
+
+java -Xmx8G -cp "$CLASS_PATH" \
+	org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
+	--model-path-prefix "$MODEL_DIR" \
+	--input-image "$INPUT_IMG" \
+	--input-dir "$INPUT_DIR"
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
new file mode 100644
index 00000000000..63b9f929a82
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
@@ -0,0 +1,116 @@
+# Single Shot Multi Object Detection using Java Inference API
+
+In this example, you will learn how to use the Java Inference API to run inference on a pre-trained Single Shot Multi Object Detection (SSD) MXNet model.
+
+The model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html). The network is an SSD model that uses ResNet-50 as the base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer to the [MXNet SSD example](https://github.com/apache/incubator-mxnet/tree/master/example/ssd).
+
+
+## Contents
+
+1. [Prerequisites](#prerequisites)
+2. [Download artifacts](#download-artifacts)
+3. [Setup datapath and parameters](#setup-datapath-and-parameters)
+4. [Run the image inference example](#how-to-run-inference)
+5. [Infer APIs](#infer-api-details)
+6. [Next steps](#next-steps)
+
+
+## Prerequisites
+
+1. MXNet
+2. MXNet Scala Package
+3. [IntelliJ IDE (or alternative IDE) project setup](http://mxnet.incubator.apache.org/tutorials/scala/mxnet_scala_on_intellij.html) with the MXNet Scala Package
+4. wget
+
+
+## Setup Guide
+
+### Download Artifacts
+#### Step 1
+You can download the files using the script `get_ssd_data.sh`. It will download and place the model files in a `model` folder and the test image files in an `image` folder in the current directory.
+From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
+
+```bash
+./get_ssd_data.sh
+```
+
+**Note**: You may need to run `chmod +x get_ssd_data.sh` before running this script.
+
+Alternatively use the following links to download the Symbol and Params files via your browser:
+- [resnet50_ssd_model-symbol.json](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json)
+- [resnet50_ssd_model-0000.params](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params)
+- [synset.txt](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/synset.txt)
+
+In the pre-trained model, the `input_name` is `data` and shape is `(1, 3, 512, 512)`.
+This shape translates to: a batch of `1` image, the image has color and uses `3` channels (RGB), and the image has the dimensions of `512` pixels in height by `512` pixels in width.
+
+`image/jpeg` is the expected input type, since this example's image pre-processor only supports the handling of binary JPEG images.
+
+The output shape is `(1, 6132, 6)`. As with the input, the `1` is the number of images. `6132` is the number of prediction results, and `6` is for the size of each prediction. Each prediction contains the following components:
+- `Class`
+- `Probability`
+- `Xmin`
+- `Ymin`
+- `Xmax`
+- `Ymax`
+
+
+### Setup Datapath and Parameters
+#### Step 2
+The example reads its paths from the command-line arguments listed below (falling back to the defaults declared in `SSDClassifierExample` when an argument is omitted), so you do not need to edit the code to use your own files.
+
+Relative paths are resolved against the JVM [working directory](https://stackoverflow.com/questions/16239130/java-user-dir-property-what-exactly-does-it-mean) (the `user.dir` system property), which is normally the directory you launch the example from. For example:
+```bash
+./run_ssd_java_example.sh <Your Own Path>/resnet50_ssd_model <Your Own Path>/dog.jpg <Your Own Path>
+```
+
+The following parameters are defined for this example; you can find more information in the `SSDClassifierExample` class.
+
+| Argument                      | Comments                                 |
+| ----------------------------- | ---------------------------------------- |
+| `model-path-prefix`                   | Folder path with prefix to the model (including json, params, and any synset file). |
+| `input-image`                 | The image to run inference on. |
+| `input-dir`                   | The directory of images to run inference on. |
+
+
+## How to Run Inference
+After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.
+
+From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
+
+```bash
+./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
+```
+
+**Notes**:
+* These are relative paths to this script.
+* You may need to run `chmod +x run_ssd_java_example.sh` before running this script.
+
+The example should give expected output as shown below:
+```
+Class: car
+Probabilties: 0.99847263
+(Coord:,312.21335,72.0291,456.01443,150.66176)
+Class: bicycle
+Probabilties: 0.90473825
+(Coord:,155.95807,149.96362,383.8369,418.94513)
+Class: dog
+Probabilties: 0.8226818
+(Coord:,83.82353,179.13998,206.63783,476.7875)
+```
+The output above comes from the input image, with the top 3 predictions picked.
+
+
+## Infer API Details
+This example uses the `ObjectDetector` class provided by the MXNet Java Infer API (`org.apache.mxnet.infer.javaapi`). It provides methods to load images from files and to run object detection on a single `BufferedImage`, on a batch of images, or on NDArray input.
+
+
+## References
+This documentation used the model and inference setup guide from the [MXNet Model Server SSD example](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/README.md).
+
+
+## Next Steps
+
+Check out the following related tutorials and examples for the Infer API:
+
+* [Image Classification with the MXNet Scala Infer API](../imageclassifier/README.md)
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
new file mode 100644
index 00000000000..13f9d2d9a3e
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.infer.javapi.objectdetector;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.mxnet.javaapi.*;
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+
+// scalastyle:off
+import java.awt.image.BufferedImage;
+// scalastyle:on
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import java.io.File;
+
+/**
+ * Runs a pre-trained SSD object detector through the MXNet Java inference API, first on a
+ * single image and then on every file of a directory, and logs the detections.
+ */
+public class SSDClassifierExample {
+	@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+	private String modelPathPrefix = "/model/ssd_resnet50_512";
+	@Option(name = "--input-image", usage = "the input image")
+	private String inputImagePath = "/images/dog.jpg";
+	@Option(name = "--input-dir", usage = "the input batch of images directory")
+	private String inputImageDir = "/images/";
+	
+	final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);
+	
+	/**
+	 * Creates a detector for the SSD model, whose single input "data" has shape
+	 * (1, 3, 512, 512) in NCHW layout.
+	 */
+	private static ObjectDetector createDetector(String modelPathPrefix, List<Context> context) {
+		Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+		List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
+		inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+		return new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+	}
+	
+	/**
+	 * Runs detection on a single image and keeps the top 3 detections.
+	 *
+	 * @param modelPathPrefix path prefix of the model files
+	 * @param inputImagePath  image file to run inference on
+	 * @param context         devices to run inference on
+	 * @return detections in the form returned by {@code ObjectDetector.imageObjectDetect}
+	 */
+	static List<List<ObjectDetectorOutput>>
+	runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List<Context> context) {
+		BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+		ObjectDetector objDet = createDetector(modelPathPrefix, context);
+		return objDet.imageObjectDetect(img, 3);
+	}
+	
+	/**
+	 * Runs detection on every regular file of a directory, in batches of up to 20 images,
+	 * keeping the top 5 detections per image.
+	 *
+	 * @param modelPathPrefix path prefix of the model files
+	 * @param inputImageDir   directory containing the images
+	 * @param context         devices to run inference on
+	 * @return one entry per batch, each in the form returned by
+	 *         {@code ObjectDetector.imageBatchObjectDetect}
+	 */
+	static List<List<List<ObjectDetectorOutput>>>
+	runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List<Context> context) {
+		ObjectDetector objDet = createDetector(modelPathPrefix, context);
+		
+		List<List<List<ObjectDetectorOutput>>> outputList
+				= new ArrayList<List<List<ObjectDetectorOutput>>>();
+		for (List<String> batchFiles : generateBatches(inputImageDir, 20)) {
+			List<BufferedImage> imgList = ObjectDetector.loadInputBatch(batchFiles);
+			outputList.add(objDet.imageBatchObjectDetect(imgList, 5));
+		}
+		return outputList;
+	}
+	
+	/**
+	 * Groups the regular files of a directory into batches of at most {@code batchSize} paths.
+	 * Files are sorted so that batch contents and image numbering are reproducible.
+	 *
+	 * @param inputImageDirPath directory to scan (sub-directories are skipped)
+	 * @param batchSize         maximum number of paths per batch; must be positive
+	 * @return batches of file paths; the last batch may be smaller
+	 * @throws IllegalArgumentException if batchSize is not positive or the path cannot be listed
+	 */
+	static List<List<String>> generateBatches(String inputImageDirPath, int batchSize) {
+		if (batchSize < 1) {
+			throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
+		}
+		File[] files = new File(inputImageDirPath).listFiles();
+		// listFiles() returns null when the path is not a directory or cannot be read.
+		if (files == null) {
+			throw new IllegalArgumentException("Cannot list directory: " + inputImageDirPath);
+		}
+		// listFiles() guarantees no particular order.
+		Arrays.sort(files);
+		
+		List<List<String>> output = new ArrayList<List<String>>();
+		List<String> batch = new ArrayList<String>();
+		for (File imgFile : files) {
+			if (!imgFile.isFile()) {
+				continue;
+			}
+			batch.add(imgFile.getPath());
+			if (batch.size() == batchSize) {
+				output.add(batch);
+				batch = new ArrayList<String>();
+			}
+		}
+		if (!batch.isEmpty()) {
+			output.add(batch);
+		}
+		return output;
+	}
+	
+	/**
+	 * Appends class, probability and pixel coordinates of each detection to {@code sb}.
+	 * x values are scaled by {@code width} and y values by {@code height} (the detector output
+	 * is presumably normalized to [0, 1]); coordinates are listed as xMin, xMax, yMin, yMax.
+	 */
+	private static void appendDetections(StringBuilder sb, List<ObjectDetectorOutput> detections,
+			int width, int height) {
+		for (ObjectDetectorOutput det : detections) {
+			sb.append("Class: ").append(det.getClassName()).append("\n");
+			sb.append("Probabilties: ").append(det.getProbability()).append("\n");
+			sb.append("Coord:")
+					.append(det.getXMin() * width).append(", ")
+					.append(det.getXMax() * width).append(", ")
+					.append(det.getYMin() * height).append(", ")
+					.append(det.getYMax() * height).append("\n");
+		}
+	}
+	
+	/**
+	 * Returns true when the SCALA_TEST_ON_GPU environment variable is 1. An unset or
+	 * non-numeric value selects the CPU instead of aborting the example.
+	 */
+	private static boolean runOnGpu() {
+		String flag = System.getenv("SCALA_TEST_ON_GPU");
+		if (flag == null) {
+			return false;
+		}
+		try {
+			return Integer.parseInt(flag.trim()) == 1;
+		} catch (NumberFormatException e) {
+			logger.warn("Ignoring non-numeric SCALA_TEST_ON_GPU value: " + flag);
+			return false;
+		}
+	}
+	
+	/**
+	 * Parses the command line, validates the paths, then runs single-image and batch detection
+	 * and logs the results. Exits with status 1 on any error.
+	 */
+	public static void main(String[] args) {
+		SSDClassifierExample inst = new SSDClassifierExample();
+		CmdLineParser parser  = new CmdLineParser(inst);
+		try {
+			parser.parseArgument(args);
+		} catch (Exception e) {
+			logger.error(e.getMessage(), e);
+			parser.printUsage(System.err);
+			System.exit(1);
+		}
+		
+		String mdprefixDir = inst.modelPathPrefix;
+		String imgPath = inst.inputImagePath;
+		String imgDir = inst.inputImageDir;
+		
+		if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
+			logger.error("Model or input image path does not exist");
+			System.exit(1);
+		}
+		
+		List<Context> context = new ArrayList<Context>();
+		context.add(runOnGpu() ? Context.gpu() : Context.cpu());
+		
+		try {
+			// NCHW layout: dimension 2 is the image height, dimension 3 its width.
+			Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+			int height = inputShape.get(2);
+			int width = inputShape.get(3);
+			
+			List<List<ObjectDetectorOutput>> output
+					= runObjectDetectionSingle(mdprefixDir, imgPath, context);
+			StringBuilder outputStr = new StringBuilder("\n");
+			for (List<ObjectDetectorOutput> detections : output) {
+				appendDetections(outputStr, detections, width, height);
+			}
+			logger.info(outputStr.toString());
+			
+			List<List<List<ObjectDetectorOutput>>> outputList =
+					runObjectDetectionBatch(mdprefixDir, imgDir, context);
+			
+			outputStr = new StringBuilder("\n");
+			int index = 0;
+			for (List<List<ObjectDetectorOutput>> batch : outputList) {
+				for (List<ObjectDetectorOutput> detections : batch) {
+					index++;
+					outputStr.append("*** Image ").append(index).append("***").append("\n");
+					appendDetections(outputStr, detections, width, height);
+				}
+			}
+			logger.info(outputStr.toString());
+		} catch (Exception e) {
+			logger.error(e.getMessage(), e);
+			parser.printUsage(System.err);
+			System.exit(1);
+		}
+		System.exit(0);
+	}
+	
+	/**
+	 * Checks that every path exists, logging each missing one.
+	 *
+	 * @param arr paths to check
+	 * @return true only if all paths exist
+	 */
+	static Boolean checkExist(List<String> arr)  {
+		boolean exist = true;
+		for (String item : arr) {
+			if (!new File(item).exists()) {
+				logger.error("Cannot find: " + item);
+				exist = false;
+			}
+		}
+		return exist;
+	}
+}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
index 69328a44bab..bf4a44a76d0 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
@@ -31,7 +31,7 @@ You can download the files using the script `get_ssd_data.sh`. It will download
 From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
 
 ```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
 ```
 
 **Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.
@@ -79,7 +79,7 @@ After the previous steps, you should be able to run the code using the following
 From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
 
 ```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
 ```
 
 **Notes**:
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
new file mode 100644
index 00000000000..6cd3df6b896
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+
+/**
+  * Java-friendly wrapper around [[org.apache.mxnet.infer.ObjectDetector]] that takes and
+  * returns `java.util` collections and wraps each detection in an [[ObjectDetectorOutput]].
+  *
+  * @param objDetector the Scala object detector that performs the inference
+  */
+class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){
+
+  /**
+    * Loads an object detector from model files.
+    *
+    * @param modelPathPrefix  Path prefix of the model files (symbol json, params, synset)
+    * @param inputDescriptors Descriptors of the model inputs
+    * @param contexts         Device contexts to run inference on
+    * @param epoch            Epoch of the params checkpoint to load, passed as `Some(epoch)`
+    */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
+  java.util.List[Context], epoch: Int)
+  = this {
+    val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+      .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
+    val inContexts = (contexts.asScala.toList map {a => a: org.apache.mxnet.Context}).toArray
+    // scalastyle:off
+    new org.apache.mxnet.infer.ObjectDetector(modelPathPrefix, informationDesc, inContexts, Some(epoch))
+    // scalastyle:on
+  }
+
+  /**
+    * Detects objects in a single image.
+    *
+    * @param inputImage       Image to run detection on
+    * @param topK             Number of detections to return, sorted by probability
+    * @return                 List of lists of detections; each detection holds the class
+    *                         name, probability and bounding box (xmin, ymin, xmax, ymax)
+    */
+  def imageObjectDetect(inputImage: BufferedImage, topK: Int):
+  java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    val ret = objDetector.imageObjectDetect(inputImage, Some(topK))
+    toJavaOutput(ret)
+  }
+
+  /**
+    * Runs detection on NDArray input. Useful when you want to perform multiple operations on
+    * the input array, or when you want to pass a batch of input images.
+    *
+    * @param input            List of input NDArrays
+    * @param topK             Number of top detections to return
+    * @return                 List of lists of detections, in the same form as
+    *                         [[imageObjectDetect]]
+    */
+  def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+  java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
+    toJavaOutput(ret)
+  }
+
+  /**
+    * Detects objects in a batch of images.
+    *
+    * @param inputBatch       Images to run detection on
+    * @param topK             Number of detections to return per image, sorted by probability
+    * @return                 List of lists of detections, in the same form as
+    *                         [[imageObjectDetect]]
+    */
+  def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+      java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
+    toJavaOutput(ret)
+  }
+
+  /** Converts every element through the implicit view `A => B`. */
+  def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+
+  /**
+    * Wraps the Scala detector's `(className, [probability, xmin, ymin, xmax, ymax])` results
+    * in Java lists of [[ObjectDetectorOutput]].
+    */
+  private def toJavaOutput(ret: Seq[Seq[(String, Array[Float])]]):
+  java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    (ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
+  }
+
+}
+
+
+/** Implicit conversions between the Java and Scala detectors, plus image-loading helpers. */
+object ObjectDetector {
+  /** Wraps a Scala detector for use through the Java API. */
+  implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
+    ObjectDetector = new ObjectDetector(OD)
+
+  /** Returns the Scala detector backing a Java-API detector. */
+  implicit def toObjectDetector(jOD: ObjectDetector):
+    org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
+
+  /**
+    * Loads an image file into a BufferedImage.
+    *
+    * @param inputImagePath Path of the image file
+    */
+  def loadImageFromFile(inputImagePath: String): BufferedImage =
+    org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
+
+  /**
+    * Loads several image files into BufferedImages.
+    *
+    * @param inputImagePaths Paths of the image files
+    */
+  def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+    val scalaPaths = inputImagePaths.asScala.toList
+    org.apache.mxnet.infer.ImageClassifier.loadInputBatch(scalaPaths).toList.asJava
+  }
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
new file mode 100644
index 00000000000..13369c8fcef
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+/** One detection: class name plus `args` = [probability, xmin, ymin, xmax, ymax]. */
+class ObjectDetectorOutput (className: String, args: Array[Float]){
+  /** Predicted class name. */
+  def getClassName: String = className
+  /** Detection probability (args(0)). */
+  def getProbability: Float = args(0)
+  /** Minimum x of the bounding box (args(1)). */
+  def getXMin: Float = args(1)
+  /** Maximum x of the bounding box (args(3); args(2) holds the minimum y). */
+  def getXMax: Float = args(3)
+  /** Minimum y of the bounding box (args(2)). */
+  def getYMin: Float = args(2)
+  /** Maximum y of the bounding box (args(4)). */
+  def getYMax: Float = args(4)
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
new file mode 100644
index 00000000000..26ccd06cf46
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+/**
+  * Java-friendly wrapper around [[org.apache.mxnet.infer.Predictor]] that takes and returns
+  * `java.util` collections.
+  *
+  * @param predictor the Scala predictor that runs the model
+  */
+class Predictor(val predictor: org.apache.mxnet.infer.Predictor){
+
+  /**
+    * Loads a predictor from model files.
+    *
+    * @param modelPathPrefix  Path prefix of the model files
+    * @param inputDescriptors Descriptors of the model inputs
+    * @param contexts         Device contexts to run inference on
+    * @param epoch            Epoch of the params checkpoint to load, passed as `Some(epoch)`
+    */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
+           contexts: java.util.List[Context], epoch: Int)
+  = this {
+    val scalaDescriptors = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+      .asScala.toIndexedSeq.map { desc => desc: org.apache.mxnet.DataDesc }
+    val scalaContexts = contexts.asScala.map { ctx => ctx: org.apache.mxnet.Context }.toArray
+    new org.apache.mxnet.infer.Predictor(modelPathPrefix, scalaDescriptors, scalaContexts,
+      Some(epoch))
+  }
+
+  /**
+    * Runs inference on raw data: one flat list of floats per model input, reshaped according
+    * to the input descriptors.
+    *
+    * @param input List with one flat data list per model input (the outer list allows models
+    *              with more than one input)
+    * @return      One flat list of floats per output of the underlying predictor
+    */
+  def predict(input: java.util.List[java.util.List[Float]]):
+  java.util.List[java.util.List[Float]] = {
+    val flatInputs = input.asScala.toIndexedSeq.map { values => values.asScala.toArray }
+    val outputs = predictor.predict(flatInputs)
+    outputs.map { out => out.toList.asJava }.asJava
+  }
+
+  /**
+    * Runs inference on NDArray input; useful when the input is a batch of data.
+    * Note: the caller is responsible for allocating and disposing the input and output NDArrays.
+    *
+    * @param input List of input NDArrays
+    * @return      Output predictions as NDArrays
+    */
+  def predictWithNDArray(input: java.util.List[NDArray]):
+  java.util.List[NDArray] = {
+    val outputs = predictor.predictWithNDArray(convert(input.asScala.toIndexedSeq))
+    // TODO: the implicit conversion back to the Java NDArray did not resolve through
+    // convert, so outputs are wrapped explicitly. Needs to be figured out.
+    outputs.map { nd => new NDArray(nd) }.asJava
+  }
+
+  /** Converts every element through the implicit view `A => B`. */
+  private def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services