You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/19 22:47:29 UTC

[incubator-mxnet] branch java-api updated: Java Inference api and SSD example (#12830)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/java-api by this push:
     new 2bc818e  Java Inference api and SSD example (#12830)
2bc818e is described below

commit 2bc818e72a3f7029c210bf5860573f11ff421886
Author: Andrew Ayres <an...@gmail.com>
AuthorDate: Fri Oct 19 15:47:14 2018 -0700

    Java Inference api and SSD example (#12830)
    
    * New Java inference API and SSD example
    
    * Adding license to java files and fixing SSD example
    
    * Fixing SSD example to point to ObjectDetector instead of ImageClassifier
    
    * Make scripts for object detector independent to os and hw cpu/gpu
    
    * Added API Docs to Java Inference API. Small fixes for PR
    
    * Cosmetic updates for API DOCS requested during PR
    
    * Attempt to fix the CI Javafx compiler issue
    
    * Migrate from Javafx to apache commons for Pair implementation
    
    * Removing javafx from pom file
    
    * Fixes to appease the ScalaStyle deity
    
    * Minor fix in SSD script and Readme
    
    * Added ObjectDetectorOutput which is a POJO for Object Detector to simplify the return type
    
    * Removing Apache Commons Immutable Pair
    
    * Adding license to new file
    
    * Minor style fixes
    
    * minor style fix
    
    * Updating to be in scala style and not explicitly declare some unnecessary variables
---
 .../infer/objectdetector/run_ssd_example.sh        |  14 +-
 ...{run_ssd_example.sh => run_ssd_java_example.sh} |  16 +-
 .../infer/javapi}/objectdetector/README.md         |   4 +-
 .../objectdetector/SSDClassifierExample.java       | 199 +++++++++++++++++++++
 .../mxnetexamples/infer/objectdetector/README.md   |   4 +-
 .../mxnet/infer/javaapi/ObjectDetector.scala       | 106 +++++++++++
 .../mxnet/infer/javaapi/ObjectDetectorOutput.scala |  34 ++++
 .../org/apache/mxnet/infer/javaapi/Predictor.scala |  69 +++++++
 8 files changed, 439 insertions(+), 7 deletions(-)

diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
index 8cea892..adb8830 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
@@ -17,9 +17,21 @@
 # specific language governing permissions and limitations
 # under the License.
 
+hw_type=cpu
+if [[ $1 = gpu ]]
+then
+    hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+    platform=osx-x86_64
+fi
 
 MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
 
 # model dir and prefix
 MODEL_DIR=$1
diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
similarity index 66%
copy from scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
copy to scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
index 8cea892..f444a3a 100755
--- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_example.sh
+++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh
@@ -17,9 +17,21 @@
 # specific language governing permissions and limitations
 # under the License.
 
+hw_type=cpu
+if [[ $4 = gpu ]]
+then
+    hw_type=gpu
+fi
+
+platform=linux-x86_64
+
+if [[ $OSTYPE = [darwin]* ]]
+then
+    platform=osx-x86_64
+fi
 
 MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
-CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*:$MXNET_ROOT/scala-package/examples/src/main/scala/org/apache/mxnetexamples/api/java/infer/imageclassifier/*
 
 # model dir and prefix
 MODEL_DIR=$1
@@ -29,7 +41,7 @@ INPUT_IMG=$2
 INPUT_DIR=$3
 
 java -Xmx8G -cp $CLASS_PATH \
-	org.apache.mxnetexamples.infer.objectdetector.SSDClassifierExample \
+	org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
 	--model-path-prefix $MODEL_DIR \
 	--input-image $INPUT_IMG \
 	--input-dir $INPUT_DIR
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
similarity index 97%
copy from scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
copy to scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
index 69328a4..63b9f92 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md
@@ -31,7 +31,7 @@ You can download the files using the script `get_ssd_data.sh`. It will download
 From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
 
 ```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
 ```
 
 **Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.
@@ -79,7 +79,7 @@ After the previous steps, you should be able to run the code using the following
 From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
 
 ```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
 ```
 
 **Notes**:
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
new file mode 100644
index 0000000..13f9d2d
--- /dev/null
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.infer.javapi.objectdetector;
+
+import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.mxnet.javaapi.*;
+import org.apache.mxnet.infer.javaapi.ObjectDetector;
+
+// scalastyle:off
+import java.awt.image.BufferedImage;
+// scalastyle:on
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import java.io.File;
+
+public class SSDClassifierExample {
+	@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
+	private String modelPathPrefix = "/model/ssd_resnet50_512";
+	@Option(name = "--input-image", usage = "the input image")
+	private String inputImagePath = "/images/dog.jpg";
+	@Option(name = "--input-dir", usage = "the input batch of images directory")
+	private String inputImageDir = "/images/";
+	
+	final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);
+	
+	static List<List<ObjectDetectorOutput>>
+	runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List<Context> context) {
+		Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+		List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
+		inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+		BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
+		ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+		return objDet.imageObjectDetect(img, 3);
+	}
+	
+	static List<List<List<ObjectDetectorOutput>>>
+	runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List<Context> context) {
+		Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
+		List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
+		inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+		ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
+		
+		// Loading batch of images from the directory path
+		List<List<String>> batchFiles = generateBatches(inputImageDir, 20);
+		List<List<List<ObjectDetectorOutput>>> outputList
+				= new ArrayList<List<List<ObjectDetectorOutput>>>();
+		
+		for (List<String> batchFile : batchFiles) {
+			List<BufferedImage> imgList = ObjectDetector.loadInputBatch(batchFile);
+			// Running inference on batch of images loaded in previous step
+			List<List<ObjectDetectorOutput>> tmp
+					= objDet.imageBatchObjectDetect(imgList, 5);
+			outputList.add(tmp);
+		}
+		return outputList;
+	}
+	
+	static List<List<String>> generateBatches(String inputImageDirPath, int batchSize) {
+		File dir = new File(inputImageDirPath);
+
+		List<List<String>> output = new ArrayList<List<String>>();
+		List<String> batch = new ArrayList<String>();
+		for (File imgFile : dir.listFiles()) {
+			batch.add(imgFile.getPath());
+			if (batch.size() == batchSize) {
+				output.add(batch);
+				batch = new ArrayList<String>();
+			}
+		}
+		if (batch.size() > 0) {
+			output.add(batch);
+		}
+		return output;
+	}
+	
+	public static void main(String[] args) {
+		SSDClassifierExample inst = new SSDClassifierExample();
+		CmdLineParser parser  = new CmdLineParser(inst);
+		try {
+			parser.parseArgument(args);
+		} catch (Exception e) {
+			logger.error(e.getMessage(), e);
+			parser.printUsage(System.err);
+			System.exit(1);
+		}
+		
+		String mdprefixDir = inst.modelPathPrefix;
+		String imgPath = inst.inputImagePath;
+		String imgDir = inst.inputImageDir;
+		
+		if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
+			logger.error("Model or input image path does not exist");
+			System.exit(1);
+		}
+		
+		List<Context> context = new ArrayList<Context>();
+		if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+				Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+			context.add(Context.gpu());
+		} else {
+			context.add(Context.cpu());
+		}
+		
+		try {
+			Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
+			Shape outputShape = new Shape(new int[] {1, 6132, 6});
+			
+			
+			int width = inputShape.get(2);
+			int height = inputShape.get(3);
+			String outputStr = "\n";
+			
+			List<List<ObjectDetectorOutput>> output
+					= runObjectDetectionSingle(mdprefixDir, imgPath, context);
+			
+			for (List<ObjectDetectorOutput> ele : output) {
+				for (ObjectDetectorOutput i : ele) {
+					outputStr += "Class: " + i.getClassName() + "\n";
+					outputStr += "Probabilties: " + i.getProbability() + "\n";
+					
+					List<Float> coord = Arrays.asList(i.getXMin() * width,
+							i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+					StringBuilder sb = new StringBuilder();
+					for (float c: coord) {
+						sb.append(", ").append(c);
+					}
+					outputStr += "Coord:" + sb.substring(2)+ "\n";
+				}
+			}
+			logger.info(outputStr);
+			
+			List<List<List<ObjectDetectorOutput>>> outputList =
+					runObjectDetectionBatch(mdprefixDir, imgDir, context);
+			
+			outputStr = "\n";
+			int index = 0;
+			for (List<List<ObjectDetectorOutput>> i: outputList) {
+				for (List<ObjectDetectorOutput> j : i) {
+					outputStr += "*** Image " + (index + 1) + "***" + "\n";
+					for (ObjectDetectorOutput k : j) {
+						outputStr += "Class: " + k.getClassName() + "\n";
+						outputStr += "Probabilties: " + k.getProbability() + "\n";
+						List<Float> coord = Arrays.asList(k.getXMin() * width,
+								k.getXMax() * height, k.getYMin() * width, k.getYMax() * height);
+						
+						StringBuilder sb = new StringBuilder();
+						for (float c : coord) {
+							sb.append(", ").append(c);
+						}
+						outputStr += "Coord:" + sb.substring(2) + "\n";
+					}
+					index++;
+				}
+			}
+			logger.info(outputStr);
+			
+		} catch (Exception e) {
+			logger.error(e.getMessage(), e);
+			parser.printUsage(System.err);
+			System.exit(1);
+		}
+		System.exit(0);
+	}
+	
+	static Boolean checkExist(List<String> arr)  {
+		Boolean exist = true;
+		for (String item : arr) {
+			exist = new File(item).exists() && exist;
+			if (!exist) {
+				logger.error("Cannot find: " + item);
+			}
+		}
+		return exist;
+	}
+}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
index 69328a4..bf4a44a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/README.md
@@ -31,7 +31,7 @@ You can download the files using the script `get_ssd_data.sh`. It will download
 From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:
 
 ```bash
-./get_resnet_data.sh
+./get_ssd_data.sh
 ```
 
 **Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.
@@ -79,7 +79,7 @@ After the previous steps, you should be able to run the code using the following
 From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:
 
 ```bash
-./run_ssd_example.sh ../models/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
 ```
 
 **Notes**:
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
new file mode 100644
index 0000000..6cd3df6
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+
+/**
+  * Java wrapper around [[org.apache.mxnet.infer.ObjectDetector]]. Every method takes and returns
+  * `java.util.List`s and forwards the call to the wrapped Scala detector.
+  *
+  * @param objDetector      Scala object detector that the calls are forwarded to
+  */
+class ObjectDetector(val objDetector: org.apache.mxnet.infer.ObjectDetector){
+
+  /**
+    * Creates the wrapped Scala detector from Java collections.
+    *
+    * @param modelPathPrefix  Path prefix handed to the Scala detector
+    * @param inputDescriptors Descriptors of the model inputs
+    * @param contexts         Contexts handed to the Scala detector
+    * @param epoch            Epoch number handed to the Scala detector
+    */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
+  java.util.List[Context], epoch: Int)
+  = this {
+    // The type ascriptions turn the javaapi wrappers into their Scala counterparts.
+    val scalaDescriptors = inputDescriptors.asScala.toIndexedSeq map {
+      desc => desc: org.apache.mxnet.DataDesc
+    }
+    val scalaContexts = (contexts.asScala map {ctx => ctx: org.apache.mxnet.Context}).toArray
+    new org.apache.mxnet.infer.ObjectDetector(
+      modelPathPrefix, scalaDescriptors, scalaContexts, Some(epoch))
+  }
+
+  /**
+    * Detects objects and returns bounding boxes with corresponding class/label
+    *
+    * @param inputImage       Image to run the detection on
+    * @param topK             Number of result elements to return, sorted by probability
+    * @return                 List of list of [[ObjectDetectorOutput]], each built from a
+    *                         (class, [probability, xmin, ymin, xmax, ymax]) result
+    */
+  def imageObjectDetect(inputImage: BufferedImage, topK: Int):
+  java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    val detections = objDetector.imageObjectDetect(inputImage, Some(topK))
+    detections.map { perImage =>
+      perImage.map(det => new ObjectDetectorOutput(det._1, det._2)).asJava
+    }.asJava
+  }
+
+  /**
+    * Takes input images as NDArrays. Useful when you want to perform multiple operations on
+    * the input array, or when you want to pass a batch of input images.
+    *
+    * @param input            List of NDArrays
+    * @param topK             How many top_k (sorting will be based on the last axis)
+    *                         elements to return
+    * @return                 List of list of [[ObjectDetectorOutput]], each built from a
+    *                         (class, [probability, xmin, ymin, xmax, ymax]) result
+    */
+  def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
+  java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    // convert is called in argument position so its target type comes from the callee.
+    val detections =
+      objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
+    detections.map { perImage =>
+      perImage.map(det => new ObjectDetectorOutput(det._1, det._2)).asJava
+    }.asJava
+  }
+
+  /**
+    * Detects objects in a batch of input images according to the provided model
+    *
+    * @param inputBatch       Input list of buffered images
+    * @param topK             Number of result elements to return, sorted by probability
+    * @return                 List of list of [[ObjectDetectorOutput]]
+    */
+  def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
+      java.util.List[java.util.List[ObjectDetectorOutput]] = {
+    val detections = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
+    detections.map { perImage =>
+      perImage.map(det => new ObjectDetectorOutput(det._1, det._2)).asJava
+    }.asJava
+  }
+
+  /** Maps every element of `l` to `B` through the view from `A` to `B`. */
+  def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { elem => elem: B }
+
+}
+
+
+/**
+  * Implicit conversions between the Java wrapper and the Scala detector, plus image-loading
+  * helpers that forward to [[org.apache.mxnet.infer.ImageClassifier]].
+  */
+object ObjectDetector {
+  /** Wraps a Scala detector in the Java wrapper. */
+  implicit def fromObjectDetector(OD: org.apache.mxnet.infer.ObjectDetector):
+    ObjectDetector = new ObjectDetector(OD)
+
+  /** Unwraps the Scala detector held by the Java wrapper. */
+  implicit def toObjectDetector(jOD: ObjectDetector):
+    org.apache.mxnet.infer.ObjectDetector = jOD.objDetector
+
+  /**
+    * Loads one image through `ImageClassifier.loadImageFromFile`.
+    *
+    * @param inputImagePath   Path of the image file
+    * @return                 The loaded image
+    */
+  def loadImageFromFile(inputImagePath: String): BufferedImage =
+    org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
+
+  /**
+    * Loads several images through `ImageClassifier.loadInputBatch`.
+    *
+    * @param inputImagePaths  Paths of the image files
+    * @return                 The loaded images as a Java list
+    */
+  def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
+    val paths = inputImagePaths.asScala.toList
+    org.apache.mxnet.infer.ImageClassifier.loadInputBatch(paths).toList.asJava
+  }
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
new file mode 100644
index 0000000..13369c8
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+/**
+  * One detection returned by the Java [[ObjectDetector]].
+  *
+  * @param className        Label of the detected class
+  * @param args             Detection values in the order documented by [[ObjectDetector]]:
+  *                         [probability, xmin, ymin, xmax, ymax]
+  */
+class ObjectDetectorOutput (className: String, args: Array[Float]){
+
+  /** Label of the detected class. */
+  def getClassName: String = className
+
+  /** Probability of the detection (index 0). */
+  def getProbability: Float = args(0)
+
+  /** Minimum x of the bounding box (index 1). */
+  def getXMin: Float = args(1)
+
+  /** Maximum x of the bounding box (index 3; index 2 holds ymin). */
+  def getXMax: Float = args(3)
+
+  /** Minimum y of the bounding box (index 2; index 3 holds xmax). */
+  def getYMin: Float = args(2)
+
+  /** Maximum y of the bounding box (index 4). */
+  def getYMax: Float = args(4)
+
+}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
new file mode 100644
index 0000000..26ccd06
--- /dev/null
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi
+
+import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
+
+import scala.collection.JavaConverters
+import scala.collection.JavaConverters._
+
+/**
+  * Java wrapper around [[org.apache.mxnet.infer.Predictor]] that exchanges `java.util.List`s
+  * with the caller and forwards every call to the wrapped Scala predictor.
+  *
+  * @param predictor        Scala predictor that the calls are forwarded to
+  */
+class Predictor(val predictor: org.apache.mxnet.infer.Predictor){
+
+  /**
+    * Creates the wrapped Scala predictor from Java collections.
+    *
+    * @param modelPathPrefix  Path prefix handed to the Scala predictor
+    * @param inputDescriptors Descriptors of the model inputs
+    * @param contexts         Contexts handed to the Scala predictor
+    * @param epoch            Epoch number handed to the Scala predictor
+    */
+  def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
+           contexts: java.util.List[Context], epoch: Int)
+  = this {
+    // The type ascriptions turn the javaapi wrappers into their Scala counterparts.
+    val scalaDescriptors = inputDescriptors.asScala.toIndexedSeq map {
+      desc => desc: org.apache.mxnet.DataDesc
+    }
+    val scalaContexts = (contexts.asScala map {ctx => ctx: org.apache.mxnet.Context}).toArray
+    new org.apache.mxnet.infer.Predictor(
+      modelPathPrefix, scalaDescriptors, scalaContexts, Some(epoch))
+  }
+
+
+  /**
+    * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
+    * The array will be reshaped based on the input descriptors.
+    *
+    * @param input            A List of a one-dimensional array.
+    *                         An extra List is needed for when the model has more than one input.
+    * @return                 List of outputs, each one as a list of floats
+    */
+  def predict(input: java.util.List[java.util.List[Float]]):
+  java.util.List[java.util.List[Float]] = {
+    val rows = input.asScala.toIndexedSeq.map(_.asScala.toArray)
+    predictor.predict(rows).map(_.toList.asJava).asJava
+  }
+
+
+  /**
+    * Predict using NDArray as input
+    * This method is useful when the input is a batch of data
+    * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
+    *
+    * @param input            List of NDArrays
+    * @return                 Output of predictions as NDArrays
+    */
+  def predictWithNDArray(input: java.util.List[NDArray]):
+  java.util.List[NDArray] = {
+    // convert is called in argument position so its target type comes from the callee.
+    val outputs = predictor.predictWithNDArray(convert(input.asScala.toIndexedSeq))
+    // TODO: For some reason the implicit wasn't working here when trying to use convert.
+    // So the results are wrapped explicitly. Needs to be figured out
+    outputs.map(nd => new NDArray(nd)).asJava
+  }
+
+  /** Maps every element of `l` to `B` through the view from `A` to `B`. */
+  private def convert[B, A <% B](l: IndexedSeq[A]): IndexedSeq[B] = l map { elem => elem: B }
+}