You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/03/26 18:39:13 UTC
[incubator-mxnet] branch master updated: [MXNET-1285] Draw bounding
box with Scala/Java Image API (#14474)
This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 092af36 [MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)
092af36 is described below
commit 092af3601bba6a24201154c3185bd3d9f39677f7
Author: Lanking <la...@live.com>
AuthorDate: Tue Mar 26 11:38:40 2019 -0700
[MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)
* new feature to draw bounding box
* add Java support
* add point wise verification
* cancel the check on top-left corner
* add this example to Java world and fixing bugs
---
scala-package/.gitignore | 1 +
.../src/main/scala/org/apache/mxnet/Image.scala | 54 ++++++++++++++++++++++
.../scala/org/apache/mxnet/javaapi/Image.scala | 46 ++++++++++++------
.../java/org/apache/mxnet/javaapi/ImageTest.java | 20 +++++++-
.../test/scala/org/apache/mxnet/ImageSuite.scala | 21 +++++++++
.../javaapi/infer/objectdetector/README.md | 2 +-
.../infer/objectdetector/SSDClassifierExample.java | 41 ++++++++++------
.../org/apache/mxnet/infer/ObjectDetector.scala | 2 +-
.../mxnet/infer/javaapi/ObjectDetectorOutput.scala | 4 +-
.../infer/javaapi/ObjectDetectorOutputTest.java | 4 +-
10 files changed, 161 insertions(+), 34 deletions(-)
diff --git a/scala-package/.gitignore b/scala-package/.gitignore
index 9bf7851..dadc000 100644
--- a/scala-package/.gitignore
+++ b/scala-package/.gitignore
@@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala
core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala
examples/scripts/infer/images/
examples/scripts/infer/models/
+examples/scripts/infer/objectdetector/boundingImage.png
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
index 0f756e2..52e26ef 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -17,6 +17,7 @@
package org.apache.mxnet
// scalastyle:off
+import java.awt.{BasicStroke, Color, Graphics2D}
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
@@ -182,4 +183,57 @@ object Image {
img
}
+  /**
+   * Helper function to generate a random color for a bounding box.
+   * @param transparency Alpha component of the color in [0, 1], where 1.0 is fully
+   *                     opaque; `None` falls back to fully opaque (1.0)
+   * @return Color with uniformly random RGB components and the requested alpha
+   * @throws IllegalArgumentException (from java.awt.Color) if the alpha is outside [0, 1]
+   */
+  private def randomColor(transparency: Option[Float] = Some(1.0f)) : Color = {
+    new Color(
+      Math.random().toFloat, Math.random().toFloat, Math.random().toFloat,
+      // getOrElse instead of .get so an explicit None does not throw NoSuchElementException
+      transparency.getOrElse(1.0f)
+    )
+  }
+
+  /**
+   * Method to draw bounding boxes (and optional labels) onto an image.
+   * The image is modified in place.
+   * @param src Source of the buffered image to draw on
+   * @param coordinate Contains Map of xmin, xmax, ymin, ymax
+   *                   corresponding to top-left and down-right points
+   * @param names The name set of the bounding box; when defined it must contain
+   *              at least one name per entry of `coordinate`
+   * @param stroke Thickness of the bounding box (`None` falls back to 3)
+   * @param fontSizeMult Font size multiplier (`None` falls back to 1.0)
+   * @param transparency Alpha of the bounding box color, 1.0 being opaque
+   *                     (`None` falls back to 1.0)
+   * @throws IllegalArgumentException if `names` has fewer entries than `coordinate`
+   */
+  def drawBoundingBox(src: BufferedImage, coordinate: Array[Map[String, Int]],
+                      names: Option[Array[String]] = None,
+                      stroke : Option[Int] = Some(3),
+                      fontSizeMult : Option[Float] = Some(1.0f),
+                      transparency: Option[Float] = Some(1.0f)): Unit = {
+    // Validate before touching the image so a mismatch does not leave it half-drawn
+    names.foreach { labels =>
+      require(labels.length >= coordinate.length,
+        s"names has ${labels.length} entries but coordinate has ${coordinate.length}")
+    }
+    // Fall back to the declared defaults instead of calling .get on a possibly-None Option
+    val strokeWidth = stroke.getOrElse(3)
+    val alpha = transparency.getOrElse(1.0f)
+    val g2d : Graphics2D = src.createGraphics()
+    try {
+      g2d.setStroke(new BasicStroke(strokeWidth))
+      // Increase the size of font
+      val currentFont = g2d.getFont
+      val newFont = currentFont.deriveFont(currentFont.getSize * fontSizeMult.getOrElse(1.0f))
+      g2d.setFont(newFont)
+      // Font metrics size the filled background tab drawn behind each label
+      val fm = g2d.getFontMetrics(newFont)
+      for ((box, idx) <- coordinate.zipWithIndex) {
+        val xmin = box("xmin")
+        val ymin = box("ymin")
+        g2d.setColor(randomColor(Some(alpha)).darker())
+        g2d.drawRect(xmin, ymin, box("xmax") - xmin, box("ymax") - ymin)
+        // Write the name of the bounding box on a filled tab ending at its top edge
+        names.foreach { labels =>
+          val label = labels(idx)
+          val x = xmin - strokeWidth
+          val h = fm.getHeight
+          val w = fm.stringWidth(label)
+          g2d.fillRect(x, ymin - h, w, h)
+          g2d.setColor(Color.WHITE)
+          g2d.drawString(label, x, ymin)
+        }
+      }
+    } finally {
+      // Release the graphics context even if a coordinate map is missing a key
+      g2d.dispose()
+    }
+  }
+
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
index 7d6f31e..f72223d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
@@ -20,15 +20,16 @@ package org.apache.mxnet.javaapi
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
+import scala.collection.JavaConverters._
object Image {
/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
- * @param buf Buffer containing binary encoded image
- * @param flag Convert decoded image to grayscale (0) or color (1).
+ * @param buf Buffer containing binary encoded image
+ * @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
- * to mxnet's default RGB format (instead of opencv's default BGR).
+ * to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
@@ -43,8 +44,8 @@ object Image {
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
- * @param flag Convert decoded image to grayscale (0) or color (1).
- * @param toRGB Whether to convert decoded image
+ * @param flag Convert decoded image to grayscale (0) or color (1).
+ * @param toRGB Whether to convert decoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
@@ -60,7 +61,7 @@ object Image {
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @param flag Convert decoded image to grayscale (0) or color (1).
- * @param toRGB Whether to convert decoded image to mxnet's default RGB format
+ * @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
@@ -74,10 +75,10 @@ object Image {
/**
* Resize image with OpenCV.
- * @param src source image in NDArray
- * @param w Width of resized image.
- * @param h Height of resized image.
- * @param interp Interpolation method (default=cv2.INTER_LINEAR).
+ * @param src source image in NDArray
+ * @param w Width of resized image.
+ * @param h Height of resized image.
+ * @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
@@ -92,10 +93,10 @@ object Image {
/**
* Do a fixed crop on the image
* @param src Src image in NDArray
- * @param x0 starting x point
- * @param y0 starting y point
- * @param w width of the image
- * @param h height of the image
+ * @param x0 starting x point
+ * @param y0 starting y point
+ * @param w width of the image
+ * @param h height of the image
* @return cropped NDArray
*/
def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
@@ -111,4 +112,21 @@ object Image {
def toImage(src: NDArray): BufferedImage = {
org.apache.mxnet.Image.toImage(src)
}
+
+  /**
+   * Draw bounding boxes on the image. The image is modified in place.
+   * @param src buffered image to draw on
+   * @param coordinate Contains Map of xmin, xmax, ymin, ymax
+   *                   corresponding to top-left and down-right points
+   * @param names The name set of the bounding box, one per entry of `coordinate`;
+   *              may be null to draw the boxes without labels
+   */
+  def drawBoundingBox(src: BufferedImage,
+                      coordinate: java.util.List[
+                        java.util.Map[java.lang.String, java.lang.Integer]],
+                      names: java.util.List[java.lang.String]): Unit = {
+    // Unbox the Java Integers into the Scala Map[String, Int] expected by the core API
+    val coord = coordinate.asScala.map(
+      _.asScala.map { case (name, value) => (name, Integer2int(value)) }.toMap).toArray
+    // Option(...) turns a null Java list into None instead of throwing an NPE on conversion
+    val labels = Option(names).map(_.asScala.toArray)
+    org.apache.mxnet.Image.drawBoundingBox(src, coord, labels)
+  }
+
}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
index 0092744..f5515dc 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
@@ -20,8 +20,15 @@ package org.apache.mxnet.javaapi;
import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;
+
+import javax.imageio.ImageIO;
+import java.awt.image.BufferedImage;
import java.io.File;
import java.net.URL;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertArrayEquals;
@@ -56,12 +63,23 @@ public class ImageTest {
}
@Test
- public void testImageProcess() {
+ public void testImageProcess() throws Exception {
NDArray nd = Image.imRead(imLocation, 1, true);
assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
NDArray nd2 = Image.imResize(nd, 224, 224, null);
assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
+ BufferedImage buf = ImageIO.read(new File(imLocation));
+ Map<String, Integer> map = new HashMap<>();
+ map.put("xmin", 190);
+ map.put("xmax", 850);
+ map.put("ymin", 50);
+ map.put("ymax", 450);
+ List<Map<String, Integer>> box = new ArrayList<>();
+ box.add(map);
+ List<String> names = new ArrayList<>();
+ names.add("pug");
+ Image.drawBoundingBox(buf, box, names);
}
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
index 67815ad..d4cf35a 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
@@ -97,4 +97,25 @@ class ImageSuite extends FunSuite with BeforeAndAfterAll {
logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out.png"}")
}
+  test("Test draw Bounding box") {
+    val buf = ImageIO.read(new File(imLocation))
+    val box = Array(
+      Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
+      Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
+    )
+    val names = Array("pug", "cookie")
+    Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    val outPath = tempDirPath + "/inputImages/out2.png"
+    ImageIO.write(buf, "png", new File(outPath))
+    logger.info(s"converted image stored in ${outPath}")
+    // Every corner on a box outline carries that box's color. The top-left corner is
+    // not checked because the label text drawn there may cover that pixel.
+    for (coord <- box) {
+      val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
+      val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
+      val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
+      assert(downLeft == downRight, s"bottom corners differ in color for box $coord")
+      assert(topRight == downRight, s"right corners differ in color for box $coord")
+    }
+  }
+
}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
index 8a9ed3e..4c4512f 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md
@@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:
```bash
-./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
+./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```
**Notes**:
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
index a9c00f7..31b8514 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java
@@ -28,12 +28,11 @@ import org.apache.mxnet.javaapi.*;
import org.apache.mxnet.infer.javaapi.ObjectDetector;
// scalastyle:off
+import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
// scalastyle:on
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+import java.util.*;
import java.io.File;
@@ -128,22 +127,34 @@ public class SSDClassifierExample {
try {
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
Shape outputShape = new Shape(new int[]{1, 6132, 6});
-
-
- int width = inputShape.get(2);
- int height = inputShape.get(3);
+
StringBuilder outputStr = new StringBuilder().append("\n");
List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(mdprefixDir, imgPath, context);
-
+
+ // Creating Bounding box material
+ BufferedImage buf = ImageIO.read(new File(imgPath));
+ int width = buf.getWidth();
+ int height = buf.getHeight();
+ List<Map<String, Integer>> boxes = new ArrayList<>();
+ List<String> names = new ArrayList<>();
for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr.append("Class: " + i.getClassName() + "\n");
outputStr.append("Probabilties: " + i.getProbability() + "\n");
-
- List<Float> coord = Arrays.asList(i.getXMin() * width,
- i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
+ names.add(i.getClassName());
+ Map<String, Integer> map = new HashMap<>();
+ float xmin = i.getXMin() * width;
+ float xmax = i.getXMax() * width;
+ float ymin = i.getYMin() * height;
+ float ymax = i.getYMax() * height;
+ List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax);
+ map.put("xmin", (int) xmin);
+ map.put("xmax", (int) xmax);
+ map.put("ymin", (int) ymin);
+ map.put("ymax", (int) ymax);
+ boxes.add(map);
StringBuilder sb = new StringBuilder();
for (float c : coord) {
sb.append(", ").append(c);
@@ -152,7 +163,12 @@ public class SSDClassifierExample {
}
}
logger.info(outputStr.toString());
-
+
+ // Draw the detected boxes onto the image and save it as boundingImage.png
+ Image.drawBoundingBox(buf, boxes, names);
+ File outputFile = new File("boundingImage.png");
+ ImageIO.write(buf, "png", outputFile);
+
List<List<List<ObjectDetectorOutput>>> outputList =
runObjectDetectionBatch(mdprefixDir, imgDir, context);
@@ -177,7 +193,6 @@ public class SSDClassifierExample {
}
}
logger.info(outputStr.toString());
-
} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
index e29f068..28a578c 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
@@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String,
if (topK.isDefined) {
var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
- // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
+ // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
result = sortedIndices.map(idx
=> (synset(predictResult(idx)(0).toInt),
predictResult(idx).takeRight(5))).toIndexedSeq
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
index 5a6ac75..32fd87e 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala
@@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){
*
* @return Float of the max X coordinate for the object bounding box
*/
- def getXMax: Float = args(2)
+ def getXMax: Float = args(3)
/**
* Gets the minimum Y coordinate for the bounding box containing the predicted object.
*
* @return Float of the min Y coordinate for the object bounding box
*/
- def getYMin: Float = args(3)
+ def getYMin: Float = args(2)
/**
* Gets the maximum Y coordinate for the bounding box containing the predicted object.
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
index 04041fc..6f3df86 100644
--- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
@@ -36,8 +36,8 @@ public class ObjectDetectorOutputTest {
Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
- Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
- Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);
}