You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/03/26 18:39:13 UTC

[incubator-mxnet] branch master updated: [MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 092af36  [MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)
092af36 is described below

commit 092af3601bba6a24201154c3185bd3d9f39677f7
Author: Lanking <la...@live.com>
AuthorDate: Tue Mar 26 11:38:40 2019 -0700

    [MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)
    
    * new feature to draw bounding box
    
    * add Java support
    
    * add point wise verification
    
    * cancel the check on top-left corner
    
    * add this example to Java world and fixing bugs
---
 scala-package/.gitignore                           |  1 +
 .../src/main/scala/org/apache/mxnet/Image.scala    | 54 ++++++++++++++++++++++
 .../scala/org/apache/mxnet/javaapi/Image.scala     | 46 ++++++++++++------
 .../java/org/apache/mxnet/javaapi/ImageTest.java   | 20 +++++++-
 .../test/scala/org/apache/mxnet/ImageSuite.scala   | 21 +++++++++
 .../javaapi/infer/objectdetector/README.md         |  2 +-
 .../infer/objectdetector/SSDClassifierExample.java | 41 ++++++++++------
 .../org/apache/mxnet/infer/ObjectDetector.scala    |  2 +-
 .../mxnet/infer/javaapi/ObjectDetectorOutput.scala |  4 +-
 .../infer/javaapi/ObjectDetectorOutputTest.java    |  4 +-
 10 files changed, 161 insertions(+), 34 deletions(-)

diff --git a/scala-package/.gitignore b/scala-package/.gitignore
index 9bf7851..dadc000 100644
--- a/scala-package/.gitignore
+++ b/scala-package/.gitignore
@@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala
 core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala
 examples/scripts/infer/images/
 examples/scripts/infer/models/
+examples/scripts/infer/objectdetector/boundingImage.png
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
index 0f756e2..52e26ef 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnet
 // scalastyle:off
+import java.awt.{BasicStroke, Color, Graphics2D}
 import java.awt.image.BufferedImage
 // scalastyle:on
 import java.io.InputStream
@@ -182,4 +183,57 @@ object Image {
     img
   }
 
+  /**
+    * Helper function to generate random colors
+    * @param transparency The transparency level
+    * @return Color
+    */
+  private def randomColor(transparency: Option[Float] = Some(1.0f)) : Color = {
+    new Color(
+      Math.random().toFloat, Math.random().toFloat, Math.random().toFloat,
+      transparency.get
+    )
+  }
+
+  /**
+    * Method to draw bounding boxes for an image
+    * @param src Source of the buffered image
+    * @param coordinate Contains Map of xmin, xmax, ymin, ymax
+    *                   corresponding to top-left and bottom-right points
+    * @param names Optional labels, one per box, in the same order as coordinate
+    * @param stroke Thickness of the bounding box
+    * @param fontSizeMult Font size multiplier
+    * @param transparency Transparency of the bounding box
+    */
+  def drawBoundingBox(src: BufferedImage, coordinate: Array[Map[String, Int]],
+                      names: Option[Array[String]] = None,
+                      stroke : Option[Int] = Some(3),
+                      fontSizeMult : Option[Float] = Some(1.0f),
+                      transparency: Option[Float] = Some(1.0f)): Unit = {
+    val g2d : Graphics2D = src.createGraphics()
+    g2d.setStroke(new BasicStroke(stroke.get))
+    // Increase the size of font
+    val currentFont = g2d.getFont
+    val newFont = currentFont.deriveFont(currentFont.getSize * fontSizeMult.get)
+    g2d.setFont(newFont)
+    // Get font metrics to draw the font box
+    val fm = g2d.getFontMetrics(newFont)
+    for (idx <- coordinate.indices) {
+      val map = coordinate(idx)
+      g2d.setColor(randomColor(transparency).darker())
+      g2d.drawRect(map("xmin"), map("ymin"), map("xmax") - map("xmin"), map("ymax") - map("ymin"))
+      // Write the name of the bounding box
+      if (names.isDefined) {
+        val x = map("xmin") - stroke.get
+        val y = map("ymin")
+        val h = fm.getHeight
+        val w = fm.charsWidth(names.get(idx).toCharArray, 0, names.get(idx).length())
+        g2d.fillRect(x, y - h, w, h)
+        g2d.setColor(Color.WHITE)
+        g2d.drawString(names.get(idx), x, y)
+      }
+    }
+    g2d.dispose()
+  }
+
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
index 7d6f31e..f72223d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
@@ -20,15 +20,16 @@ package org.apache.mxnet.javaapi
 import java.awt.image.BufferedImage
 // scalastyle:on
 import java.io.InputStream
+import scala.collection.JavaConverters._
 
 object Image {
   /**
     * Decode image with OpenCV.
     * Note: return image in RGB by default, instead of OpenCV's default BGR.
-    * @param buf    Buffer containing binary encoded image
-    * @param flag   Convert decoded image to grayscale (0) or color (1).
+    * @param buf   Buffer containing binary encoded image
+    * @param flag  Convert decoded image to grayscale (0) or color (1).
     * @param toRGB Whether to convert decoded image
-    *               to mxnet's default RGB format (instead of opencv's default BGR).
+    *              to mxnet's default RGB format (instead of opencv's default BGR).
     * @return NDArray in HWC format with DType [[DType.UInt8]]
     */
   def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
@@ -43,8 +44,8 @@ object Image {
     * Same imageDecode with InputStream
     *
     * @param inputStream the inputStream of the image
-    * @param flag   Convert decoded image to grayscale (0) or color (1).
-    * @param toRGB Whether to convert decoded image
+    * @param flag        Convert decoded image to grayscale (0) or color (1).
+    * @param toRGB       Whether to convert decoded image
     * @return NDArray in HWC format with DType [[DType.UInt8]]
     */
   def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
@@ -60,7 +61,7 @@ object Image {
     * Note: return image in RGB by default, instead of OpenCV's default BGR.
     * @param filename Name of the image file to be loaded.
     * @param flag     Convert decoded image to grayscale (0) or color (1).
-    * @param toRGB   Whether to convert decoded image to mxnet's default RGB format
+    * @param toRGB    Whether to convert decoded image to mxnet's default RGB format
     *                 (instead of opencv's default BGR).
     * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
     */
@@ -74,10 +75,10 @@ object Image {
 
   /**
     * Resize image with OpenCV.
-    * @param src     source image in NDArray
-    * @param w       Width of resized image.
-    * @param h       Height of resized image.
-    * @param interp  Interpolation method (default=cv2.INTER_LINEAR).
+    * @param src    source image in NDArray
+    * @param w      Width of resized image.
+    * @param h      Height of resized image.
+    * @param interp Interpolation method (default=cv2.INTER_LINEAR).
     * @return org.apache.mxnet.NDArray
     */
   def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
@@ -92,10 +93,10 @@ object Image {
   /**
     * Do a fixed crop on the image
     * @param src Src image in NDArray
-    * @param x0 starting x point
-    * @param y0 starting y point
-    * @param w width of the image
-    * @param h height of the image
+    * @param x0  starting x point
+    * @param y0  starting y point
+    * @param w   width of the image
+    * @param h   height of the image
     * @return cropped NDArray
     */
   def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
@@ -111,4 +112,21 @@ object Image {
   def toImage(src: NDArray): BufferedImage = {
     org.apache.mxnet.Image.toImage(src)
   }
+
+  /**
+    * Draw bounding boxes on the image
+    * @param src        buffered image to draw on
+    * @param coordinate Contains Map of xmin, xmax, ymin, ymax
+    *                   corresponding to top-left and bottom-right points
+    * @param names      Labels, one per box, in the same order as coordinate
+    */
+  def drawBoundingBox(src: BufferedImage,
+                      coordinate: java.util.List[
+                        java.util.Map[java.lang.String, java.lang.Integer]],
+                      names: java.util.List[java.lang.String]): Unit = {
+    val coord = coordinate.asScala.map(
+      _.asScala.map{case (name, value) => (name, Integer2int(value))}.toMap).toArray
+    org.apache.mxnet.Image.drawBoundingBox(src, coord, Option(names.asScala.toArray))
+  }
+
 }
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
index 0092744..f5515dc 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
@@ -20,8 +20,15 @@ package org.apache.mxnet.javaapi;
 import org.apache.commons.io.FileUtils;
 import org.junit.BeforeClass;
 import org.junit.Test;
+
+import javax.imageio.ImageIO;
+import java.awt.image.BufferedImage;
 import java.io.File;
 import java.net.URL;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 import static org.junit.Assert.assertArrayEquals;
 
@@ -56,12 +63,23 @@ public class ImageTest {
     }
 
     @Test
-    public void testImageProcess() {
+    public void testImageProcess() throws Exception {
         NDArray nd = Image.imRead(imLocation, 1, true);
         assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
         NDArray nd2 = Image.imResize(nd, 224, 224, null);
         assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
         NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
         Image.toImage(cropped);
+        BufferedImage buf = ImageIO.read(new File(imLocation));
+        Map<String, Integer> map = new HashMap<>();
+        map.put("xmin", 190);
+        map.put("xmax", 850);
+        map.put("ymin", 50);
+        map.put("ymax", 450);
+        List<Map<String, Integer>> box = new ArrayList<>();
+        box.add(map);
+        List<String> names = new ArrayList<>();
+        names.add("pug");
+        Image.drawBoundingBox(buf, box, names);
     }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
index 67815ad..d4cf35a 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
@@ -97,4 +97,25 @@ class ImageSuite extends FunSuite with BeforeAndAfterAll {
     logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out.png"}")
   }
 
+  test("Test draw Bounding box") {
+    val buf = ImageIO.read(new File(imLocation))
+    val box = Array(
+      Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
+      Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
+    )
+    val names = Array("pug", "cookie")
+    Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    ImageIO.write(buf, "png", new File(tempDirPath + "/inputImages/out2.png"))
+    logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out2.png"}")
+    for (coord <- box) {
+      val topLeft = buf.getRGB(coord("xmin"), coord("ymin"))
+      val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
+      val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
+      val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
+      require(downLeft == downRight)
+      require(topRight == downRight)
+    }
+  }
+
 }
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
index 8a9ed3e..4c4512f 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
@@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following
 From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
 
 ```bash
-./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
 ```
 
 **Notes**:
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
index a9c00f7..31b8514 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
@@ -28,12 +28,11 @@ import org.apache.mxnet.javaapi.*;
 import org.apache.mxnet.infer.javaapi.ObjectDetector;
 
 // scalastyle:off
+import javax.imageio.ImageIO;
 import java.awt.image.BufferedImage;
 // scalastyle:on
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+import java.util.*;
 
 import java.io.File;
 
@@ -128,22 +127,34 @@ public class SSDClassifierExample {
         try {
             Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
             Shape outputShape = new Shape(new int[]{1, 6132, 6});
-            
-            
-            int width = inputShape.get(2);
-            int height = inputShape.get(3);
+
             StringBuilder outputStr = new StringBuilder().append("\n");
             
             List<List<ObjectDetectorOutput>> output
                     = runObjectDetectionSingle(mdprefixDir, imgPath, context);
-            
+
+            // Creating Bounding box material
+            BufferedImage buf = ImageIO.read(new File(imgPath));
+            int width = buf.getWidth();
+            int height = buf.getHeight();
+            List<Map<String, Integer>> boxes = new ArrayList<>();
+            List<String> names = new ArrayList<>();
             for (List<ObjectDetectorOutput> ele : output) {
                 for (ObjectDetectorOutput i : ele) {
                     outputStr.append("Class: " + i.getClassName() + "\n");
                     outputStr.append("Probabilties: " + i.getProbability() + "\n");
-                    
-                    List<Float> coord = Arrays.asList(i.getXMin() * width,
-                            i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+                    names.add(i.getClassName());
+                    Map<String, Integer> map = new HashMap<>();
+                    float xmin = i.getXMin() * width;
+                    float xmax = i.getXMax() * width;
+                    float ymin = i.getYMin() * height;
+                    float ymax = i.getYMax() * height;
+                    List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax);
+                    map.put("xmin", (int) xmin);
+                    map.put("xmax", (int) xmax);
+                    map.put("ymin", (int) ymin);
+                    map.put("ymax", (int) ymax);
+                    boxes.add(map);
                     StringBuilder sb = new StringBuilder();
                     for (float c : coord) {
                         sb.append(", ").append(c);
@@ -152,7 +163,12 @@ public class SSDClassifierExample {
                 }
             }
             logger.info(outputStr.toString());
-            
+
+            // Draw the bounding boxes and save the resulting image
+            Image.drawBoundingBox(buf, boxes, names);
+            File outputFile = new File("boundingImage.png");
+            ImageIO.write(buf, "png", outputFile);
+
             List<List<List<ObjectDetectorOutput>>> outputList =
                     runObjectDetectionBatch(mdprefixDir, imgDir, context);
             
@@ -177,7 +193,6 @@ public class SSDClassifierExample {
                 }
             }
             logger.info(outputStr.toString());
-            
         } catch (Exception e) {
             logger.error(e.getMessage(), e);
             parser.printUsage(System.err);
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
index e29f068..28a578c 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
@@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String,
     if (topK.isDefined) {
       var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2)
       sortedIndices = sortedIndices.take(topK.get)
-      // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
+      // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
       result = sortedIndices.map(idx
       => (synset(predictResult(idx)(0).toInt),
           predictResult(idx).takeRight(5))).toIndexedSeq
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
index 5a6ac75..32fd87e 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){
     *
     * @return       Float of the max X coordinate for the object bounding box
     */
-  def getXMax: Float = args(2)
+  def getXMax: Float = args(3)
 
   /**
     * Gets the minimum Y coordinate for the bounding box containing the predicted object.
     *
     * @return       Float of the min Y coordinate for the object bounding box
     */
-  def getYMin: Float = args(3)
+  def getYMin: Float = args(2)
 
   /**
     * Gets the maximum Y coordinate for the bounding box containing the predicted object.
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
index 04041fc..6f3df86 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
@@ -36,8 +36,8 @@ public class ObjectDetectorOutputTest {
         Assert.assertEquals(odOutput.getClassName(), predictedClassName);
         Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
         Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
-        Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
-        Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
+        Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta);
+        Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta);
         Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);
 
     }