You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/02 15:02:50 UTC

[GitHub] gigasquid closed pull request #12995: [MXNET-1180] Scala Image API

gigasquid closed pull request #12995: [MXNET-1180] Scala Image API
URL: https://github.com/apache/incubator-mxnet/pull/12995
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
new file mode 100644
index 00000000000..43f81a22a40
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+import java.io.InputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+
+/**
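+  * Image API of the Scala package.
+  * Wraps the OpenCV-backed image operators (decode, read, resize, pad border)
+  * and adds helpers for cropping an image and converting it to a BufferedImage.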
+  */
+object Image {
+
+  /**
+    * Decode a binary encoded image held in memory with OpenCV.
+    * Note: OpenCV decodes to BGR; pass to_rgb = true to get the image in RGB instead.
+    * @param buf    Buffer containing binary encoded image
+    * @param flag   Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb Whether to convert decoded image
+    *               to mxnet's default RGB format (instead of opencv's default BGR).
+    * @param out    Optional NDArray passed to the operator as its "out" argument.
+    * @return NDArray in HWC format
+    */
+  def imDecode(buf: Array[Byte], flag: Int,
+               to_rgb: Boolean,
+               out: Option[NDArray]): NDArray = {
+    // JVM bytes are signed, so a plain toFloat turns every byte >= 0x80 into a
+    // negative float that the uint8 cast below cannot represent. Mask with 0xFF
+    // first so each byte is widened to its unsigned value (0..255).
+    val nd = NDArray.array(buf.map(b => (b & 0xFF).toFloat), Shape(buf.length))
+    val byteND = NDArray.api.cast(nd, "uint8")
+    val args : ListBuffer[Any] = ListBuffer()
+    val map : mutable.Map[String, Any] = mutable.Map()
+    args += byteND
+    map("flag") = flag
+    map("to_rgb") = to_rgb
+    out.foreach(o => map.update("out", o))
+    NDArray.genericNDArrayFunctionInvoke("_cvimdecode", args, map.toMap)
+  }
+
+  /**
+    * Decode an image supplied as an InputStream.
+    * The stream is drained into memory in 2048-byte chunks until end-of-stream,
+    * then the collected bytes are decoded by the byte-array variant of imDecode.
+    * The stream is not closed by this method.
+    * @param inputStream the inputStream of the image
+    * @param flag        Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb      Whether to convert decoded image
+    *                    to mxnet's default RGB format (instead of opencv's default BGR).
+    * @param out         Optional NDArray passed on to the byte-array variant.
+    * @return NDArray in HWC format
+    */
+  def imDecode(inputStream: InputStream, flag: Int = 1,
+               to_rgb: Boolean = true,
+               out: Option[NDArray] = None): NDArray = {
+    val chunk = new Array[Byte](2048)
+    val collected = ArrayBuffer[Byte]()
+    // read() reports -1 at end-of-stream; every other result is the number of
+    // bytes that were placed at the start of `chunk` by that call
+    Iterator.continually(inputStream.read(chunk))
+      .takeWhile(_ != -1)
+      .foreach(count => collected ++= chunk.slice(0, count))
+    imDecode(collected.toArray, flag, to_rgb, out)
+  }
+
+  /**
+    * Read an image file and decode it with OpenCV.
+    * Optional arguments that are left as None are not sent to the operator.
+    * @param filename Name of the image file to be loaded.
+    * @param flag     Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb   Whether to convert decoded image to mxnet's default RGB format
+    *                 (instead of opencv's default BGR).
+    * @param out      Optional NDArray passed to the operator as its "out" argument.
+    * @return org.apache.mxnet.NDArray in HWC format
+    */
+  def imRead(filename: String, flag: Option[Int] = None,
+             to_rgb: Option[Boolean] = None,
+             out: Option[NDArray] = None): NDArray = {
+    // this operator receives no positional inputs, only named ones
+    val positional = ListBuffer.empty[Any]
+    val kwargs = mutable.Map[String, Any]("filename" -> filename)
+    flag.foreach(f => kwargs.update("flag", f))
+    to_rgb.foreach(rgb => kwargs.update("to_rgb", rgb))
+    out.foreach(target => kwargs.update("out", target))
+    NDArray.genericNDArrayFunctionInvoke("_cvimread", positional, kwargs.toMap)
+  }
+
+  /**
+    * Resize an image with OpenCV.
+    * @param src     source image in NDArray
+    * @param w       Width of resized image.
+    * @param h       Height of resized image.
+    * @param interp  Interpolation method (default=cv2.INTER_LINEAR);
+    *                not sent to the operator when None.
+    * @param out     Optional NDArray passed to the operator as its "out" argument.
+    * @return org.apache.mxnet.NDArray
+    */
+  def imResize(src: org.apache.mxnet.NDArray, w: Int, h: Int,
+               interp: Option[Int] = None,
+               out: Option[NDArray] = None): NDArray = {
+    // the source image is the only positional input
+    val positional = ListBuffer[Any](src)
+    val kwargs = mutable.Map[String, Any]("w" -> w, "h" -> h)
+    interp.foreach(method => kwargs.update("interp", method))
+    out.foreach(target => kwargs.update("out", target))
+    NDArray.genericNDArrayFunctionInvoke("_cvimresize", positional, kwargs.toMap)
+  }
+
+  /**
+    * Pad the border of an image with OpenCV.
+    * Optional arguments that are left as None are not sent to the operator.
+    * @param src    source image
+    * @param top    Top margin.
+    * @param bot    Bottom margin.
+    * @param left   Left margin.
+    * @param right  Right margin.
+    * @param typeOf Filling type (default=cv2.BORDER_CONSTANT); sent as "type".
+    * @param value  (Deprecated! Use ``values`` instead.) Fill with single value.
+    * @param values Fill with value(RGB[A] or gray), up to 4 channels.
+    * @param out    Optional NDArray passed to the operator as its "out" argument.
+    * @return org.apache.mxnet.NDArray
+    */
+  def copyMakeBorder(src: org.apache.mxnet.NDArray, top: Int, bot: Int,
+                     left: Int, right: Int, typeOf: Option[Int] = None,
+                     value: Option[Double] = None, values: Option[Any] = None,
+                     out: Option[NDArray] = None): NDArray = {
+    // the source image is the only positional input
+    val positional = ListBuffer[Any](src)
+    val kwargs = mutable.Map[String, Any](
+      "top" -> top,
+      "bot" -> bot,
+      "left" -> left,
+      "right" -> right)
+    typeOf.foreach(fill => kwargs.update("type", fill))
+    value.foreach(single => kwargs.update("value", single))
+    values.foreach(perChannel => kwargs.update("values", perChannel))
+    out.foreach(target => kwargs.update("out", target))
+    NDArray.genericNDArrayFunctionInvoke("_cvcopyMakeBorder", positional, kwargs.toMap)
+  }
+
+  /**
+    * Crop a fixed rectangle out of an image.
+    * The crop spans rows y0 until y0 + h and columns x0 until x0 + w, and keeps
+    * the full extent of the third axis of src.
+    * @param src Src image in NDArray; must have at least three axes
+    * @param x0 starting x point
+    * @param y0 starting y point
+    * @param w width of the cropped region
+    * @param h height of the cropped region
+    * @return cropped NDArray
+    */
+  def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
+    val depth = src.shape.get(2)
+    // shapes are ordered (row, column, third axis), hence y before x
+    val begin = Shape(y0, x0, 0)
+    val end = Shape(y0 + h, x0 + w, depth)
+    NDArray.api.crop(src, begin, end)
+  }
+
+  /**
+    * Convert a NDArray image to a BufferedImage.
+    * The pixel data is copied out of the NDArray once and then packed into
+    * 0xRRGGBB integers, so the cost grows linearly with the number of pixels.
+    * @param src Source image in RGB, laid out as (height, width, channel) with
+    *            dtype UInt8 and at least three channels; only the first three
+    *            channels of each pixel are used
+    * @return Buffered Image of type TYPE_INT_RGB
+    */
+  def toImage(src: NDArray): BufferedImage = {
+    require(src.dtype == DType.UInt8, "The input NDArray must be bytes")
+    require(src.shape.length == 3, "The input should contains height, width and channel")
+    val height = src.shape.get(0)
+    val width = src.shape.get(1)
+    val channels = src.shape.get(2)
+    // Fail with a clear message instead of an index error while reading pixels.
+    require(channels >= 3, "The input should have at least three channels (RGB)")
+    // Copy the data out a single time. Slicing the NDArray per pixel creates two
+    // temporary NDArrays and one copy for every pixel, none of which are released.
+    val data = src.toArray
+    // Pixel i is row i / width, column i % width; its channel values sit next to
+    // each other starting at i * channels in the flattened array.
+    val pixels = Array.tabulate(width * height) { i =>
+      val base = i * channels
+      // NDArray in RGB
+      val red = data(base).toByte & 0xFF
+      val green = data(base + 1).toByte & 0xFF
+      val blue = data(base + 2).toByte & 0xFF
+      (red << 16) | (green << 8) | blue
+    }
+    val img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
+    // one bulk write instead of one setRGB call per pixel from parallel threads
+    img.setRGB(0, 0, width, height, pixels, 0, width)
+    img
+  }
+
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
new file mode 100644
index 00000000000..67815ad6c10
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.io.File
+import java.net.URL
+
+import javax.imageio.ImageIO
+import org.apache.commons.io.FileUtils
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+class ImageSuite extends FunSuite with BeforeAndAfterAll {
+  private var imLocation = ""
+  private val logger = LoggerFactory.getLogger(classOf[ImageSuite])
+
+  /**
+    * Download url to filePath unless the file already exists.
+    * @param url      location to download from
+    * @param filePath destination file
+    * @param maxRetry number of attempts before giving up (3 when None)
+    * @throws Exception when every attempt failed; the last failure is attached as cause
+    */
+  private def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
+    val tmpFile = new File(filePath)
+    var retry = maxRetry.getOrElse(3)
+    // an already present file counts as a successful download
+    var success = tmpFile.exists()
+    var lastError: Option[Throwable] = None
+    while (retry > 0 && !success) {
+      try {
+        FileUtils.copyURLToFile(new URL(url), tmpFile)
+        success = true
+      } catch {
+        case e: Exception =>
+          retry -= 1
+          lastError = Some(e)
+          // Drop whatever was written so far: otherwise a later run would find
+          // the file, skip the download and test against a truncated image.
+          FileUtils.deleteQuietly(tmpFile)
+          logger.warn(s"Downloading $url failed, $retry attempt(s) left", e)
+      }
+    }
+    if (!success) throw new Exception(s"$url Download failed!", lastError.orNull)
+  }
+
+  /**
+    * Fetch the sample image into the system temp directory before any test runs
+    * and remember where it was stored.
+    */
+  override def beforeAll(): Unit = {
+    val imageUrl = "https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg"
+    val tmpDir = System.getProperty("java.io.tmpdir")
+    imLocation = s"$tmpDir/inputImages/Pug-Cookie.jpg"
+    downloadUrl(imageUrl, imLocation)
+  }
+
+  // Reading the downloaded sample image from disk yields a height x width x channel array.
+  test("Test load image") {
+    val nd = Image.imRead(imLocation)
+    logger.info(s"OpenCV load image with shape: ${nd.shape}")
+    // assert (not require) so that a mismatch is reported as a test failure
+    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
+  }
+
+  // Decoding straight from a URL stream gives the same shape as reading the file.
+  test("Test load image from Socket") {
+    val url = new URL("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg")
+    val inputStream = url.openStream
+    // imDecode only reads from the stream, so the connection is released here
+    val nd = try {
+      Image.imDecode(inputStream)
+    } finally {
+      inputStream.close()
+    }
+    logger.info(s"OpenCV load image with shape: ${nd.shape}")
+    // assert (not require) so that a mismatch is reported as a test failure
+    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
+  }
+
+  // Resizing to 224 x 224 changes height and width and keeps the three channels.
+  test("Test resize image") {
+    val nd = Image.imRead(imLocation)
+    val resizeIm = Image.imResize(nd, 224, 224)
+    logger.info(s"OpenCV resize image with shape: ${resizeIm.shape}")
+    // assert (not require) so that a mismatch is reported as a test failure
+    assert(resizeIm.shape == Shape(224, 224, 3), "image shape not Match!")
+  }
+
+  // Cropping a 224 x 224 region from the top-left corner keeps all channels.
+  test("Test crop image") {
+    val nd = Image.imRead(imLocation)
+    val nd2 = Image.fixedCrop(nd, 0, 0, 224, 224)
+    // assert (not require) so that a mismatch is reported as a test failure
+    assert(nd2.shape == Shape(224, 224, 3), "image shape not Match!")
+  }
+
+  // A one pixel border on every side grows height and width by two each.
+  test("Test apply border") {
+    val nd = Image.imRead(imLocation)
+    val nd2 = Image.copyMakeBorder(nd, 1, 1, 1, 1)
+    // assert (not require) so that a mismatch is reported as a test failure
+    assert(nd2.shape == Shape(578, 1026, 3), "image shape not Match!")
+  }
+
+  // Converting a resized NDArray gives a BufferedImage of the same size that
+  // can be written out as a PNG.
+  test("Test convert to Image") {
+    val nd = Image.imRead(imLocation)
+    val resizeIm = Image.imResize(nd, 224, 224)
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    val outPath = tempDirPath + "/inputImages/out.png"
+    val img = Image.toImage(resizeIm)
+    assert(img.getWidth == 224 && img.getHeight == 224, "converted image size not Match!")
+    // ImageIO.write returns false when no writer handles the requested format
+    assert(ImageIO.write(img, "png", new File(outPath)), "png image was not written")
+    logger.info(s"converted image stored in $outPath")
+  }
+
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services