You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/19 22:47:29 UTC
[incubator-mxnet] branch java-api updated: Java Inference api and
SSD example (#12830)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/java-api by this push:
new 2bc818e Java Inference api and SSD example (#12830)
2bc818e is described below
commit 2bc818e72a3f7029c210bf5860573f11ff421886
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Fri Oct 19 15:47:14 2018 -0700
Java Inference api and SSD example (#12830)
* New Java inference API and SSD example
* Adding license to java files and fixing SSD example
* Fixing SSD example to point to ObjectDetector instead of ImageClassifier
* Make scripts for object detector independent to os and hw cpu/gpu
* Added API Docs to Java Inference API. Small fixes for PR
* Cosmetic updates for API DOCS requested during PR
* Attempt to fix the CI Javafx compiler issue
* Migrate from Javafx to apache commons for Pair implementation
* Removing javafx from pom file
* Fixes to appease the ScalaStyle deity
* Minor fix in SSD script and Readme
* Added ObjectDetectorOutput which is a POJO for Object Detector to simplify the return type
* Removing Apache Commons Immutable Pair
* Adding license to new file
* Minor style fixes
* minor style fix
* Updating to be in scala style and not explicitly declare some unnecessary variables
---
.../infer/objectdetector/run_ssd_example.sh | 14 +-
...{run_ssd_example.sh => run_ssd_java_example.sh} | 16 +-
.../infer/javapi}/objectdetector/README.md | 4 +-
.../objectdetector/SSDClassifierExample.java | 199 +++++++++++++++++++++
.../mxnetexamples/infer/objectdetector/README.md | 4 +-
.../mxnet/infer/javaapi/ObjectDetector.scala | 106 +++++++++++
.../mxnet/infer/javaapi/ObjectDetectorOutput.scala | 34 ++++
.../org/apache/mxnet/infer/javaapi/Predictor.scala | 69 +++++++
8 files changed, 439 insertions(+), 7 deletions(-)
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
index 8cea892..adb8830 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
@@ -17,9 +17,21 @@
# specific language governing permissions and limitations
# under the License.
+hw_type=cpu # pass "gpu" as the 4th argument (after model prefix, image and image dir) for GPU
+if [[ $4 = gpu ]]
+then
+  hw_type=gpu
+fi
+
+platform=linux-x86_64 # replaced by osx-x86_64 below when running on macOS
+
+if [[ $OSTYPE = darwin* ]] # plain glob: [darwin]* would match any OSTYPE starting with d,a,r,w,i,n
+then
+  platform=osx-x86_64
+fi
MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
# model dir and prefix
MODEL_DIR=$1
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
similarity index 66%
copy from scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
copy to scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
index 8cea892..f444a3a 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
@@ -17,9 +17,21 @@
# specific language governing permissions and limitations
# under the License.
+hw_type=cpu # pass "gpu" as the 4th argument (after model prefix, image and image dir) for GPU
+if [[ $4 = gpu ]]
+then
+  hw_type=gpu
+fi
+
+platform=linux-x86_64 # replaced by osx-x86_64 below when running on macOS
+
+if [[ $OSTYPE = darwin* ]] # plain glob: [darwin]* would match any OSTYPE starting with d,a,r,w,i,n
+then
+  platform=osx-x86_64
+fi
MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
# model dir and prefix
MODEL_DIR=$1
@@ -29,7 +41,7 @@ INPUT_IMG=$2
INPUT_DIR=$3
java -Xmx8G -cp $CLASS_PATH \
- org.apache.mxnetexamples.infer.objectdetector.SSDClassifierExample \
+ org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
--model-path-prefix $MODEL_DIR \
--input-image $INPUT_IMG \
--input-dir $INPUT_DIR
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
similarity index 97%
copy from scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
copy to scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
index 69328a4..63b9f92 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
@@ -31,7 +31,7 @@ You can download the files using the script `get_ssd_data.sh`. It will download
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
```
**Note**: You may need to run `chmod +x get_ssd_data.sh` before running this script.
@@ -79,7 +79,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```
**Notes**:
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
new file mode 100644
index 0000000..13f9d2d
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.infer.javapi.objectdetector;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.mxnet.javaapi.*;
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+
+// scalastyle:off
+import java.awt.image.BufferedImage;
+// scalastyle:on
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import java.io.File;
+
+public class SSDClassifierExample {
+ @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+ private String modelPathPrefix = "/model/ssd_resnet50_512";
+ @Option(name = "--input-image", usage = "the input image")
+ private String inputImagePath = "/images/dog.jpg";
+ @Option(name = "--input-dir", usage = "the input batch of images directory")
+ private String inputImageDir = "/images/";
+
+ final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);
+
+ static List<List<ObjectDetectorOutput>>
+ runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List<Context> context) {
+ Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+ List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+ return objDet.imageObjectDetect(img, 3);
+ }
+
+ static List<List<List<ObjectDetectorOutput>>>
+ runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List<Context> context) {
+ Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
+ List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
+ inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+
+ // Loading batch of images from the directory path
+ List<List<String>> batchFiles = generateBatches(inputImageDir, 20);
+ List<List<List<ObjectDetectorOutput>>> outputList
+ = new ArrayList<List<List<ObjectDetectorOutput>>>();
+
+ for (List<String> batchFile : batchFiles) {
+ List<BufferedImage> imgList = ObjectDetector.loadInputBatch(batchFile);
+ // Running inference on batch of images loaded in previous step
+ List<List<ObjectDetectorOutput>> tmp
+ = objDet.imageBatchObjectDetect(imgList, 5);
+ outputList.add(tmp);
+ }
+ return outputList;
+ }
+
+ static List<List<String>> generateBatches(String inputImageDirPath, int batchSize) {
+ File dir = new File(inputImageDirPath);
+
+ List<List<String>> output = new ArrayList<List<String>>();
+ List<String> batch = new ArrayList<String>();
+ for (File imgFile : dir.listFiles()) {
+ batch.add(imgFile.getPath());
+ if (batch.size() == batchSize) {
+ output.add(batch);
+ batch = new ArrayList<String>();
+ }
+ }
+ if (batch.size() > 0) {
+ output.add(batch);
+ }
+ return output;
+ }
+
+ public static void main(String[] args) {
+ SSDClassifierExample inst = new SSDClassifierExample();
+ CmdLineParser parser = new CmdLineParser(inst);
+ try {
+ parser.parseArgument(args);
+ } catch (Exception e) {
+ logger.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+
+ String mdprefixDir = inst.modelPathPrefix;
+ String imgPath = inst.inputImagePath;
+ String imgDir = inst.inputImageDir;
+
+ if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
+ logger.error("Model or input image path does not exist");
+ System.exit(1);
+ }
+
+ List<Context> context = new ArrayList<Context>();
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+ context.add(Context.gpu());
+ } else {
+ context.add(Context.cpu());
+ }
+
+ try {
+ Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+ Shape outputShape = new Shape(new int[] {1, 6132, 6});
+
+
+ int width = inputShape.get(2);
+ int height = inputShape.get(3);
+ String outputStr = "\n";
+
+ List<List<ObjectDetectorOutput>> output
+ = runObjectDetectionSingle(mdprefixDir, imgPath, context);
+
+ for (List<ObjectDetectorOutput> ele : output) {
+ for (ObjectDetectorOutput i : ele) {
+ outputStr += "Class: " + i.getClassName() + "\n";
+ outputStr += "Probabilties: " + i.getProbability() + "\n";
+
+ List<Float> coord = Arrays.asList(i.getXMin() * width,
+ i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+ StringBuilder sb = new StringBuilder();
+ for (float c: coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr += "Coord:" + sb.substring(2)+ "\n";
+ }
+ }
+ logger.info(outputStr);
+
+ List<List<List<ObjectDetectorOutput>>> outputList =
+ runObjectDetectionBatch(mdprefixDir, imgDir, context);
+
+ outputStr = "\n";
+ int index = 0;
+ for (List<List<ObjectDetectorOutput>> i: outputList) {
+ for (List<ObjectDetectorOutput> j : i) {
+ outputStr += "*** Image " + (index + 1) + "***" + "\n";
+ for (ObjectDetectorOutput k : j) {
+ outputStr += "Class: " + k.getClassName() + "\n";
+ outputStr += "Probabilties: " + k.getProbability() + "\n";
+ List<Float> coord = Arrays.asList(k.getXMin() * width,
+ k.getXMax() * height, k.getYMin() * width, k.getYMax() * height);
+
+ StringBuilder sb = new StringBuilder();
+ for (float c : coord) {
+ sb.append(", ").append(c);
+ }
+ outputStr += "Coord:" + sb.substring(2) + "\n";
+ }
+ index++;
+ }
+ }
+ logger.info(outputStr);
+
+ } catch (Exception e) {
+ logger.error(e.getMessage(), e);
+ parser.printUsage(System.err);
+ System.exit(1);
+ }
+ System.exit(0);
+ }
+
+ static Boolean checkExist(List<String> arr) {
+ Boolean exist = true;
+ for (String item : arr) {
+ exist = new File(item).exists() && exist;
+ if (!exist) {
+ logger.error("Cannot find: " + item);
+ }
+ }
+ return exist;
+ }
+}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
index 69328a4..bf4a44a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
@@ -31,7 +31,7 @@ You can download the files using the script `get_ssd_data.sh`. It will download
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
```
**Note**: You may need to run `chmod +x get_ssd_data.sh` before running this script.
@@ -79,7 +79,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```
**Notes**:
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
new file mode 100644
index 0000000..6cd3df6
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+
+/**
+ * Java-friendly wrapper around the Scala [[org.apache.mxnet.infer.ObjectDetector]].
+ *
+ * Java collections and `javaapi` types are converted to their Scala counterparts before
+ * delegating, and every detection is returned as an [[ObjectDetectorOutput]].
+ *
+ * @param objDetector The Scala object detector that performs the inference
+ */
+class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){
+
+  /**
+   * Loads a model and builds the underlying Scala object detector.
+   *
+   * @param modelPathPrefix  Path prefix from where to load the model artifacts
+   * @param inputDescriptors Descriptors of the model inputs
+   * @param contexts         Contexts to run inference on
+   * @param epoch            Epoch of the model parameters to load
+   */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
+           contexts: java.util.List[Context], epoch: Int)
+    = this {
+      // The type ascriptions rely on implicit javaapi -> core conversions defined elsewhere.
+      val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+        .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
+      val inContexts = (contexts.asScala.toList map {a => a: org.apache.mxnet.Context}).toArray
+      new org.apache.mxnet.infer.ObjectDetector(modelPathPrefix, informationDesc, inContexts,
+        Some(epoch))
+    }
+
+  /**
+   * Detects objects in a single image.
+   *
+   * @param inputImage Image to run detection on
+   * @param topK       Number of result elements to return, sorted by probability
+   * @return Detections grouped per image, each holding the class name, probability and
+   *         bounding box of one detected object
+   */
+  def imageObjectDetect(inputImage: BufferedImage, topK: Int):
+    java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    toJavaOutput(objDetector.imageObjectDetect(inputImage, Some(topK)))
+  }
+
+  /**
+   * Takes input images as NDArrays. Useful when you want to perform multiple operations on
+   * the input array, or when you want to pass a batch of input images.
+   *
+   * @param input List of NDArrays
+   * @param topK  Number of top elements to return (sorting is based on the last axis)
+   * @return Detections grouped per input, each holding the class name, probability and
+   *         bounding box of one detected object
+   */
+  def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+    java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    toJavaOutput(objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq),
+      Some(topK)))
+  }
+
+  /**
+   * Detects objects in a batch of images.
+   *
+   * @param inputBatch List of images
+   * @param topK       Number of result elements to return per image, sorted by probability
+   * @return One list of detections per input image
+   */
+  def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+    java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    toJavaOutput(objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK)))
+  }
+
+  /** Converts each element of `l` to `B` using the implicit view `A => B` in scope. */
+  def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+
+  /** Wraps every raw `(className, [probability, xmin, ymin, xmax, ymax])` detection. */
+  private def toJavaOutput(detections: Seq[Seq[(String, Array[Float])]]):
+    java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    detections.map { perImage =>
+      perImage.map { case (name, values) => new ObjectDetectorOutput(name, values) }.asJava
+    }.asJava
+  }
+
+}
+
+
+/** Implicit conversions and image-loading helpers for the Java object detector API. */
+object ObjectDetector {
+  /** Wraps a Scala object detector in the Java-friendly API. */
+  implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
+    ObjectDetector = new ObjectDetector(OD)
+
+  /** Unwraps the Java-friendly API to the underlying Scala object detector. */
+  implicit def toObjectDetector(jOD: ObjectDetector):
+    org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
+
+  /** Loads one image from disk via `ImageClassifier.loadImageFromFile`. */
+  def loadImageFromFile(inputImagePath: String): BufferedImage =
+    org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
+
+  /** Loads every image in `inputImagePaths` via `ImageClassifier.loadInputBatch`. */
+  def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+    val scalaPaths = inputImagePaths.asScala.toList
+    val images = org.apache.mxnet.infer.ImageClassifier.loadInputBatch(scalaPaths)
+    images.toList.asJava
+  }
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
new file mode 100644
index 0000000..13369c8
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+/**
+ * One detection returned by the Java object detector API.
+ *
+ * @param className Label of the detected class
+ * @param args      Detection values in the order `[probability, xmin, ymin, xmax, ymax]`,
+ *                  as produced by the Scala object detector
+ */
+class ObjectDetectorOutput (className: String, args: Array[Float]){
+
+  /** @return Label of the detected class */
+  def getClassName: String = className
+
+  /** @return Probability of the detection */
+  def getProbability: Float = args(0)
+
+  /** @return Minimum x coordinate of the bounding box */
+  def getXMin: Float = args(1)
+
+  /** @return Maximum x coordinate of the bounding box (stored after ymin, at index 3) */
+  def getXMax: Float = args(3)
+
+  /** @return Minimum y coordinate of the bounding box (stored after xmin, at index 2) */
+  def getYMin: Float = args(2)
+
+  /** @return Maximum y coordinate of the bounding box */
+  def getYMax: Float = args(4)
+
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
new file mode 100644
index 0000000..26ccd06
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+/**
+ * Java-friendly wrapper around the Scala [[org.apache.mxnet.infer.Predictor]].
+ *
+ * @param predictor The Scala predictor that performs the inference
+ */
+class Predictor(val predictor: org.apache.mxnet.infer.Predictor){
+
+  /**
+   * Loads a model and builds the underlying Scala predictor.
+   *
+   * @param modelPathPrefix  Path prefix from where to load the model artifacts
+   * @param inputDescriptors Descriptors of the model inputs
+   * @param contexts         Contexts to run inference on
+   * @param epoch            Epoch of the model parameters to load
+   */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
+           contexts: java.util.List[Context], epoch: Int)
+    = this {
+      // The type ascriptions rely on implicit javaapi -> core conversions defined elsewhere.
+      val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
+        .asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
+      val inContexts = (contexts.asScala.toList map {a => a: org.apache.mxnet.Context}).toArray
+      new org.apache.mxnet.infer.Predictor(modelPathPrefix, informationDesc, inContexts,
+        Some(epoch))
+    }
+
+  /**
+   * Takes input as a List of one-dimensional arrays and creates the NDArrays needed for
+   * inference. The arrays are reshaped based on the input descriptors.
+   *
+   * NOTE(review): `Float` here is `scala.Float`; in a generic position its Java signature is
+   * likely erased to `Object`, so Java callers may see `List<List<Object>>`. Consider
+   * `java.lang.Float` in a future revision (source-incompatible for Scala callers).
+   *
+   * @param input A List of one-dimensional arrays; the outer List is needed for when the
+   *              model has more than one input
+   * @return List of one-dimensional output arrays
+   */
+  def predict(input: java.util.List[java.util.List[Float]]):
+    java.util.List[java.util.List[Float]] = {
+    val in = JavaConverters.asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq
+    (predictor.predict(in map {a => a.asScala.toArray}) map {b => b.toList.asJava}).asJava
+  }
+
+  /**
+   * Predict using NDArray as input.
+   * This method is useful when the input is a batch of data.
+   * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
+   *
+   * @param input List of NDArrays
+   * @return Output of predictions as NDArrays
+   */
+  def predictWithNDArray(input: java.util.List[NDArray]):
+    java.util.List[NDArray] = {
+    val ret = predictor.predictWithNDArray(convert(JavaConverters
+      .asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
+    // TODO: For some reason the implicit wasn't working here when trying to use convert.
+    // So did it this way. Needs to be figured out
+    (ret map {a => new NDArray(a)}).asJava
+  }
+
+  /** Converts each element of `l` to `B` using the implicit view `A => B` in scope. */
+  private def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { a => a: B }
+}