You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by aa...@apache.org on 2019/02/28 21:14:40 UTC

[incubator-mxnet] branch master updated: MXNet Java bug fixes and experience improvement (#14213)

This is an automated email from the ASF dual-hosted git repository.

aaronmarkham pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c319ae5  MXNet Java bug fixes and experience improvement (#14213)
c319ae5 is described below

commit c319ae570c1cfdd01e0349dbdb77d7774e28c98e
Author: Lanking <la...@live.com>
AuthorDate: Thu Feb 28 13:14:19 2019 -0800

    MXNet Java bug fixes and experience improvement (#14213)
    
    * improve Java user experience
    
    * add the new examples
    
    * fixed based on the comments
---
 .../scala/org/apache/mxnet/javaapi/NDArray.scala   |   2 +
 scala-package/mxnet-demo/java-demo/README.md       |  18 ++-
 .../mxnet-demo/java-demo/bin/java_sample.sh        |   2 +-
 scala-package/mxnet-demo/java-demo/bin/run_od.sh   |   2 +-
 scala-package/mxnet-demo/java-demo/pom.xml         |   6 +
 .../src/main/java/mxnet/ImageClassification.java   | 131 +++++++++++++++++++++
 .../{HelloWorld.java => NDArrayCreation.java}      |  29 +++--
 .../{HelloWorld.java => NDArrayOperation.java}     |  26 ++--
 .../src/main/java/mxnet/ObjectDetection.java       |  16 ++-
 9 files changed, 201 insertions(+), 31 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 67809c1..50139ec 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -242,6 +242,8 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     this(NDArray.array(arr, shape, ctx))
   }
 
+  override def toString: String = nd.toString
+
   def serialize(): Array[Byte] = nd.serialize()
 
   /**
diff --git a/scala-package/mxnet-demo/java-demo/README.md b/scala-package/mxnet-demo/java-demo/README.md
index 7602193..cad52cb 100644
--- a/scala-package/mxnet-demo/java-demo/README.md
+++ b/scala-package/mxnet-demo/java-demo/README.md
@@ -16,9 +16,15 @@
 <!--- under the License. -->
 
 # MXNet Java Sample Project
-This is an project created to use Maven-published Scala/Java package with two Java examples.
+This is a project demonstrating how to use the Maven-published Scala/Java MXNet package.
+The examples provided include:
+* NDArray creation
+* NDArray operation
+* Object Detection using the Inference API
+* Image Classification using the Predictor API
+
 ## Setup
-You are required to use Maven to build the package with the following commands:
+You are required to use Maven to build the package with the following commands under `java-demo`:
 ```
 mvn package
 ```
@@ -42,16 +48,16 @@ The `SCALA_PKG_PROFILE` should be chosen from `osx-x86_64-cpu`, `linux-x86_64-cp
 
 
 ## Run
-### Hello World
-The Scala file is being executed using Java. You can execute the helloWorld example as follows:
+### NDArrayCreation
+The Scala file is being executed using Java. You can execute the `NDArrayCreation` example as follows:
 ```Bash
 bash bin/java_sample.sh
 ```
 You can also run the following command manually:
 ```Bash
-java -cp $CLASSPATH sample.HelloWorld
+java -cp $CLASSPATH mxnet.NDArrayCreation
 ```
-However, you have to define the Classpath before you run the demo code. More information can be found in the `java_sample.sh`.
+However, you have to define the Classpath before you run the demo code. More information can be found in `bin/java_sample.sh`.
 The `CLASSPATH` should point to the jar file you have downloaded.
 
 It will load the library automatically and run the example
diff --git a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
index 4fb724a..fb1795f 100755
--- a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
+++ b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
@@ -17,4 +17,4 @@
 #!/bin/bash
 CURR_DIR=$(cd $(dirname $0)/../; pwd)
 CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
-java -Xmx8G  -cp $CLASSPATH mxnet.HelloWorld
\ No newline at end of file
+java -Xmx8G  -cp $CLASSPATH mxnet.NDArrayCreation
diff --git a/scala-package/mxnet-demo/java-demo/bin/run_od.sh b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
index abd0bf5..4370518 100755
--- a/scala-package/mxnet-demo/java-demo/bin/run_od.sh
+++ b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
@@ -17,4 +17,4 @@
 #!/bin/bash
 CURR_DIR=$(cd $(dirname $0)/../; pwd)
 CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
-java -Xmx8G  -cp $CLASSPATH mxnet.ObjectDetection
\ No newline at end of file
+java -Xmx8G  -cp $CLASSPATH mxnet.ObjectDetection
diff --git a/scala-package/mxnet-demo/java-demo/pom.xml b/scala-package/mxnet-demo/java-demo/pom.xml
index 39253b1..eb5e043 100644
--- a/scala-package/mxnet-demo/java-demo/pom.xml
+++ b/scala-package/mxnet-demo/java-demo/pom.xml
@@ -83,6 +83,12 @@
             <version>${mxnet.version}</version>
         </dependency>
         <dependency>
+            <groupId>org.apache.mxnet</groupId>
+            <artifactId>mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}</artifactId>
+            <version>${mxnet.version}</version>
+            <classifier>sources</classifier>
+        </dependency>
+        <dependency>
             <groupId>commons-io</groupId>
             <artifactId>commons-io</artifactId>
             <version>2.4</version>
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java
new file mode 100644
index 0000000..8cb58da
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package mxnet;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ImageClassification {
+    private static String modelPath;
+    private static String imagePath;
+
+    private static void downloadUrl(String url, String filePath) { // best-effort download helper
+        File tmpFile = new File(filePath);
+        if (!tmpFile.exists()) { // only download when filePath is not already present
+            try {
+                FileUtils.copyURLToFile(new URL(url), tmpFile);
+            } catch (Exception exception) {
+                System.err.println(exception); // failure is reported on stderr, not rethrown
+            }
+        }
+    }
+
+    public static void downloadModelImage() {
+        String tempDirPath = System.getProperty("java.io.tmpdir");
+        String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
+        downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
+                tempDirPath + "/resnet18/resnet-18-symbol.json");
+        downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
+                tempDirPath + "/resnet18/resnet-18-0000.params");
+        downloadUrl(baseUrl + "/resnet-18/synset.txt",
+                tempDirPath + "/resnet18/synset.txt");
+        downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+                tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg");
+        modelPath = tempDirPath + File.separator + "resnet18/resnet-18";
+        imagePath = tempDirPath + File.separator +
+                "inputImages/resnet18/Pug-Cookie.jpg";
+    }
+
+    /**
+     * Finds the highest probability and pairs it with the label at the same index in synset.txt.
+     * @param probabilities class scores, assumed to be ordered like the lines of synset.txt
+     * @param modelPathPrefix model path prefix; synset.txt is read from its parent directory
+     * @return text of the form "Probability : p Class : label"
+     * @throws IOException if synset.txt cannot be opened or read
+     */
+    private static String printMaximumClass(float[] probabilities,
+                                            String modelPathPrefix) throws IOException {
+        // Resolve via File so mixed "/" and File.separator in the prefix cannot break the lookup
+        File synsetFile = new File(new File(modelPathPrefix).getParentFile(), "synset.txt");
+        List<String> labels = new ArrayList<>();
+        try (BufferedReader reader = new BufferedReader(new FileReader(synsetFile))) {
+            String line; // try-with-resources closes the reader even if readLine() throws
+            while ((line = reader.readLine()) != null) {
+                labels.add(line);
+            }
+        }
+
+        int maxIdx = 0;
+        for (int i = 1; i < probabilities.length; i++) {
+            if (probabilities[i] > probabilities[maxIdx]) {
+                maxIdx = i;
+            }
+        }
+
+        return "Probability : " + probabilities[maxIdx] + " Class : " + labels.get(maxIdx);
+    }
+
+    public static void main(String[] args) {
+        // Download the model files and the sample image; also sets modelPath and imagePath
+        downloadModelImage();
+
+        // Prepare the model: CPU context and one Float32 input "data" of shape 1x3x224x224 (NCHW)
+        List<Context> context = new ArrayList<Context>();
+        context.add(Context.cpu());
+        List<DataDesc> inputDesc = new ArrayList<>();
+        Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
+        inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+        Predictor predictor = new Predictor(modelPath, inputDesc, context,0);
+
+        // Prepare data: read the image, resize it to 224x224, then reshape it to match inputShape
+        NDArray nd = Image.imRead(imagePath, 1, true);
+        nd = Image.imResize(nd, 224, 224, null);
+        nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];  // HWC to CHW
+        nd = NDArray.expand_dims(nd, 0, null)[0]; // Add N -> NCHW
+        nd = nd.asType(DType.Float32()); // Inference with Float32
+
+        // Predict from a flat float array; result[0] is passed on as the class probabilities
+        float[][] result = predictor.predict(new float[][]{nd.toArray()});
+        try {
+            System.out.println("Predict with Float input");
+            System.out.println(printMaximumClass(result[0], modelPath));
+        } catch (IOException e) {
+            System.err.println(e); // printMaximumClass could not read synset.txt
+        }
+
+        // Predict again, this time passing the NDArray itself instead of a float array
+        List<NDArray> ndList = new ArrayList<>();
+        ndList.add(nd);
+        List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
+        try {
+            System.out.println("Predict with NDArray");
+            System.out.println(printMaximumClass(ndResult.get(0).toArray(), modelPath));
+        } catch (IOException e) {
+            System.err.println(e); // printMaximumClass could not read synset.txt
+        }
+    }
+}
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java
similarity index 59%
copy from scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
copy to scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java
index 71981e2..32e2d84 100644
--- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java
@@ -14,19 +14,34 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package mxnet;
 
 import org.apache.mxnet.javaapi.*;
-import java.util.Arrays;
 
-public class HelloWorld {
+public class NDArrayCreation {
     static NDArray$ NDArray = NDArray$.MODULE$;
-
     public static void main(String[] args) {
-    	System.out.println("Hello World!");
+
+        // Create new NDArray
         NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
-        System.out.println(nd.shape());
-        NDArray nd2 = NDArray.dot(NDArray.new dotParam(nd, nd.T()))[0];
-        System.out.println(Arrays.toString(nd2.toArray()));
+        System.out.println(nd);
+
+        // create new Double NDArray
+        NDArray ndDouble = new NDArray(new double[]{2.0d, 3.0d}, new Shape(new int[]{2, 1}), Context.cpu());
+        System.out.println(ndDouble);
+
+        // create ones
+        NDArray ones = NDArray.ones(Context.cpu(), new int[] {1, 2, 3});
+        System.out.println(ones);
+
+        // random
+        NDArray random = NDArray.random_uniform(
+                NDArray.new random_uniformParam()
+                        .setLow(0.0f)
+                        .setHigh(2.0f)
+                        .setShape(new Shape(new int[]{10, 10}))
+        )[0];
+        System.out.println(random);
     }
 }
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
similarity index 67%
rename from scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
rename to scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
index 71981e2..56a4143 100644
--- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
@@ -14,19 +14,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package mxnet;
 
 import org.apache.mxnet.javaapi.*;
-import java.util.Arrays;
 
-public class HelloWorld {
+public class NDArrayOperation {
     static NDArray$ NDArray = NDArray$.MODULE$;
-
     public static void main(String[] args) {
-    	System.out.println("Hello World!");
         NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
-        System.out.println(nd.shape());
-        NDArray nd2 = NDArray.dot(NDArray.new dotParam(nd, nd.T()))[0];
-        System.out.println(Arrays.toString(nd2.toArray()));
+
+        // Transpose
+        NDArray ndT = nd.T();
+        System.out.println(nd);
+        System.out.println(ndT);
+
+        // change Data Type
+        NDArray ndInt = nd.asType(DType.Int32());
+        System.out.println(ndInt);
+
+        // element add
+        NDArray eleAdd = NDArray.elemwise_add(nd, nd, null)[0];
+        System.out.println(eleAdd);
+
+        // norm (L2 Norm)
+        NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
+        System.out.println(normed);
     }
 }
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
index cfe9b66..65fe286 100644
--- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
@@ -68,20 +68,18 @@ public class ObjectDetection {
 
     public static void main(String[] args) {
         List<Context> context = new ArrayList<Context>();
-        if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
-                Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
-            context.add(Context.gpu());
-        } else {
-            context.add(Context.cpu());
-        }
+        context.add(Context.cpu());
         downloadModelImage();
+
+        List<List<ObjectDetectorOutput>> output
+                = runObjectDetectionSingle(modelPath, imagePath, context);
+        
         Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
         Shape outputShape = new Shape(new int[] {1, 6132, 6});
         int width = inputShape.get(2);
         int height = inputShape.get(3);
-        List<List<ObjectDetectorOutput>> output
-                = runObjectDetectionSingle(modelPath, imagePath, context);
         String outputStr = "\n";
+        
         for (List<ObjectDetectorOutput> ele : output) {
             for (ObjectDetectorOutput i : ele) {
                 outputStr += "Class: " + i.getClassName() + "\n";
@@ -98,4 +96,4 @@ public class ObjectDetection {
         }
         System.out.println(outputStr);
     }
-}
\ No newline at end of file
+}