You are viewing a plain-text version of this content; the canonical link ("here") is a hyperlink that is not preserved in this version.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/31 18:55:30 UTC

[incubator-mxnet] branch master updated: [MXNET-1180] Java Image API (#13807)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a3e4a0  [MXNET-1180] Java Image API (#13807)
9a3e4a0 is described below

commit 9a3e4a02ded9c6d2c304557140dac0c9991d507e
Author: Lanking <la...@live.com>
AuthorDate: Thu Jan 31 10:55:12 2019 -0800

    [MXNET-1180] Java Image API (#13807)
    
    * add java example
    
    * add test and change PredictorExample
    
    * add image change
    
    * Add minor fixes
    
    * add License
    
    * add predictor Example tests
    
    * fix the issue with JUnit test
    
    * Satisfy Lint God ʕ •ᴥ•ʔ
    
    * update the pom file config
    
    * update documentation
    
    * add simplified methods
---
 scala-package/core/pom.xml                         |   6 --
 .../src/main/scala/org/apache/mxnet/Image.scala    |   6 +-
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |   6 +-
 .../scala/org/apache/mxnet/javaapi/Image.scala     | 114 +++++++++++++++++++++
 .../java/org/apache/mxnet/javaapi/ImageTest.java   |  67 ++++++++++++
 scala-package/examples/pom.xml                     |   1 +
 .../javaapi/infer/predictor/PredictorExample.java  |  88 ++--------------
 .../main/scala/org/apache/mxnetexamples/Util.scala |   4 +-
 .../infer/predictor/PredictorExampleTest.java      |  67 ++++++++++++
 scala-package/infer/pom.xml                        |   8 --
 .../apache/mxnet/javaapi/JavaNDArrayMacro.scala    |   4 +-
 scala-package/pom.xml                              |   6 ++
 12 files changed, 274 insertions(+), 103 deletions(-)

diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 7264c39..4de65c0 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -139,12 +139,6 @@
       <scope>provided</scope>
     </dependency>
     <dependency>
-      <groupId>junit</groupId>
-      <artifactId>junit</artifactId>
-      <version>4.11</version>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
       <groupId>commons-io</groupId>
       <artifactId>commons-io</artifactId>
       <version>2.1</version>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
index 77881ab..0f756e2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -37,7 +37,7 @@ object Image {
     * @param flag   Convert decoded image to grayscale (0) or color (1).
     * @param to_rgb Whether to convert decoded image
     *               to mxnet's default RGB format (instead of opencv's default BGR).
-    * @return NDArray in HWC format
+    * @return NDArray in HWC format with DType [[DType.UInt8]]
     */
   def imDecode(buf: Array[Byte], flag: Int,
                to_rgb: Boolean,
@@ -56,7 +56,7 @@ object Image {
   /**
     * Same imageDecode with InputStream
     * @param inputStream the inputStream of the image
-    * @return NDArray in HWC format
+    * @return NDArray in HWC format with DType [[DType.UInt8]]
     */
   def imDecode(inputStream: InputStream, flag: Int = 1,
                to_rgb: Boolean = true,
@@ -78,7 +78,7 @@ object Image {
     * @param flag     Convert decoded image to grayscale (0) or color (1).
     * @param to_rgb   Whether to convert decoded image to mxnet's default RGB format
     *                 (instead of opencv's default BGR).
-    * @return org.apache.mxnet.NDArray in HWC format
+    * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
     */
   def imRead(filename: String, flag: Option[Int] = None,
              to_rgb: Option[Boolean] = None,
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 5c345f2..4324b3d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -97,9 +97,11 @@ object NDArray extends NDArrayBase {
           case ndArr: Seq[NDArray @unchecked] =>
             if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle))
             else throw new IllegalArgumentException(
-              "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
+              s"""Unsupported out ${output.getClass} type,
+                 | should be NDArray or subclass of Seq[NDArray]""".stripMargin)
           case _ => throw new IllegalArgumentException(
-            "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
+            s"""Unsupported out ${output.getClass} type,
+               | should be NDArray or subclass of Seq[NDArray]""".stripMargin)
         }
       } else {
         (null, null)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
new file mode 100644
index 0000000..7d6f31e
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+import java.io.InputStream
+
+object Image {
+  /**
+    * Decode image with OpenCV.
+    * Note: return image in RGB by default, instead of OpenCV's default BGR.
+    * @param buf    Buffer containing binary encoded image
+    * @param flag   Convert decoded image to grayscale (0) or color (1).
+    * @param toRGB  Whether to convert decoded image
+    *               to mxnet's default RGB format (instead of opencv's default BGR).
+    * @return NDArray in HWC format with DType [[DType.UInt8]]
+    */
+  def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
+    org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
+  }
+
+  /** Decode `buf` as a color (flag = 1) RGB image; see the 3-argument overload. */
+  def imDecode(buf: Array[Byte]): NDArray =
+    imDecode(buf, 1, true)
+
+  /**
+    * Same as the byte-buffer imDecode, but reads the encoded image from an InputStream.
+    * @param inputStream the inputStream of the image
+    * @param flag   Convert decoded image to grayscale (0) or color (1).
+    * @param toRGB  Whether to convert decoded image to mxnet's default RGB format
+    *               (instead of opencv's default BGR).
+    * @return NDArray in HWC format with DType [[DType.UInt8]]
+    */
+  def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
+    org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
+  }
+
+  /** Decode `inputStream` as a color (flag = 1) RGB image; see the 3-argument overload. */
+  def imDecode(inputStream: InputStream): NDArray =
+    imDecode(inputStream, 1, true)
+
+  /**
+    * Read and decode image with OpenCV.
+    * Note: return image in RGB by default, instead of OpenCV's default BGR.
+    * @param filename Name of the image file to be loaded.
+    * @param flag     Convert decoded image to grayscale (0) or color (1).
+    * @param toRGB    Whether to convert decoded image to mxnet's default RGB format
+    *                 (instead of opencv's default BGR).
+    * @return NDArray in HWC format with DType [[DType.UInt8]]
+    */
+  def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = {
+    org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
+  }
+
+  /** Read `filename` as a color (flag = 1) RGB image; see the 3-argument overload. */
+  def imRead(filename: String): NDArray =
+    imRead(filename, 1, true)
+
+  /**
+    * Resize image with OpenCV.
+    * @param src     source image in NDArray
+    * @param w       Width of resized image.
+    * @param h       Height of resized image.
+    * @param interp  Interpolation method; null selects the default (cv2.INTER_LINEAR).
+    * @return resized NDArray
+    */
+  def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
+    val interpVal = Option(interp).map(_.intValue()) // null means "use the default"
+    org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
+  }
+
+  /** Resize `src` to `w` x `h` with the default interpolation method. */
+  def imResize(src: NDArray, w: Int, h: Int): NDArray =
+    imResize(src, w, h, null)
+
+  /**
+    * Do a fixed crop on the image
+    * @param src Src image in NDArray
+    * @param x0 starting x point
+    * @param y0 starting y point
+    * @param w width of the cropped region
+    * @param h height of the cropped region
+    * @return cropped NDArray
+    */
+  def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
+    org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
+  }
+
+  /**
+    * Convert an NDArray image into a java.awt BufferedImage.
+    * The conversion cost grows with the image resolution.
+    * @param src Source image NDArray in RGB
+    * @return Buffered Image
+    */
+  def toImage(src: NDArray): BufferedImage = {
+    org.apache.mxnet.Image.toImage(src)
+  }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
new file mode 100644
index 0000000..0092744
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.apache.commons.io.FileUtils;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import java.io.File;
+import java.net.URL;
+
+import static org.junit.Assert.assertArrayEquals;
+
+public class ImageTest {
+
+    private static String imLocation;
+
+    /** Download url to filePath (skipped if it exists), trying at most maxRetry times. */
+    private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception {
+        File tmpFile = new File(filePath);
+        if (tmpFile.exists()) return;
+        Exception lastError = null;
+        for (int attempt = 0; attempt < maxRetry; attempt++) {
+            try {
+                FileUtils.copyURLToFile(new URL(url), tmpFile);
+                return;
+            } catch (Exception e) {
+                // remember the failure so the final exception carries its cause
+                lastError = e;
+            }
+        }
+        throw new Exception(url + " Download failed!", lastError);
+    }
+
+    @BeforeClass
+    public static void downloadFile() throws Exception {
+        String tempDirPath = System.getProperty("java.io.tmpdir");
+        imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
+        downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+                imLocation, 3);
+    }
+
+    @Test
+    public void testImageProcess() {
+        NDArray nd = Image.imRead(imLocation, 1, true);
+        assertArrayEquals(new int[]{576, 1024, 3}, nd.shape().toArray());
+        NDArray nd2 = Image.imResize(nd, 224, 224, null);
+        assertArrayEquals(new int[]{224, 224, 3}, nd2.shape().toArray());
+        NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
+        assertArrayEquals(new int[]{224, 224, 3}, cropped.shape().toArray());
+        Image.toImage(cropped);
+    }
+}
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 564102a..30ccfdc 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -15,6 +15,7 @@
 
   <properties>
     <skipTests>true</skipTests>
+    <skipJavaTests>${skipTests}</skipJavaTests>
   </properties>
 
   <build>
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
index c9b4426..c5d2099 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
@@ -24,8 +24,6 @@ import org.kohsuke.args4j.Option;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import javax.imageio.ImageIO;
-import java.awt.Graphics2D;
 import java.awt.image.BufferedImage;
 import java.io.BufferedReader;
 import java.io.File;
@@ -47,76 +45,7 @@ public class PredictorExample {
     private String inputImagePath = "/images/dog.jpg";
 
     final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);
-
-    /**
-     * Load the image from file to buffered image
-     * It can be replaced by loadImageFromFile from ObjectDetector
-     * @param inputImagePath input image Path in String
-     * @return Buffered image
-     */
-    private static BufferedImage loadIamgeFromFile(String inputImagePath) {
-        BufferedImage buf = null;
-        try {
-            buf = ImageIO.read(new File(inputImagePath));
-        } catch (IOException e) {
-            System.err.println(e);
-        }
-        return buf;
-    }
-
-    /**
-     * Reshape the current image using ImageIO and Graph2D
-     * It can be replaced by reshapeImage from ObjectDetector
-     * @param buf Buffered image
-     * @param newWidth desired width
-     * @param newHeight desired height
-     * @return a reshaped bufferedImage
-     */
-    private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
-        BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
-        Graphics2D g = resizedImage.createGraphics();
-        g.drawImage(buf, 0, 0, newWidth, newHeight, null);
-        g.dispose();
-        return resizedImage;
-    }
-
-    /**
-     * Convert an image from a buffered image into pixels float array
-     * It can be replaced by bufferedImageToPixels from ObjectDetector
-     * @param buf buffered image
-     * @return Float array
-     */
-    private static float[] imagePreprocess(BufferedImage buf) {
-        // Get height and width of the image
-        int w = buf.getWidth();
-        int h = buf.getHeight();
-
-        // get an array of integer pixels in the default RGB color mode
-        int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w);
-
-        // 3 times height and width for R,G,B channels
-        float[] result = new float[3 * h * w];
-
-        int row = 0;
-        // copy pixels to array vertically
-        while (row < h) {
-            int col = 0;
-            // copy pixels to array horizontally
-            while (col < w) {
-                int rgb = pixels[row * w + col];
-                // getting red color
-                result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF;
-                // getting green color
-                result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF;
-                // getting blue color
-                result[2 * h * w + row * w + col] = rgb & 0xFF;
-                col += 1;
-            }
-            row += 1;
-        }
-        buf.flush();
-        return result;
-    }
+    private static NDArray$ NDArray = NDArray$.MODULE$;
 
     /**
      * Helper class to print the maximum prediction result
@@ -170,11 +99,10 @@ public class PredictorExample {
         inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
         Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
         // Prepare data
-        BufferedImage img = loadIamgeFromFile(inst.inputImagePath);
-
-        img = reshapeImage(img, 224, 224);
+        NDArray img = Image.imRead(inst.inputImagePath, 1, true);
+        img = Image.imResize(img, 224, 224, null);
         // predict
-        float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
+        float[][] result = predictor.predict(new float[][]{img.toArray()});
         try {
             System.out.println("Predict with Float input");
             System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
@@ -182,10 +110,10 @@ public class PredictorExample {
             System.err.println(e);
         }
         // predict with NDArray
-        NDArray nd = new NDArray(
-                imagePreprocess(img),
-                new Shape(new int[]{1, 3, 224, 224}),
-                Context.cpu());
+        NDArray nd = img;
+        nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
+        nd = NDArray.expand_dims(nd, 0, null)[0];
+        nd = nd.asType(DType.Float32());
         List<NDArray> ndList = new ArrayList<>();
         ndList.add(nd);
         List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
index c1ff10c..dba34316 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
@@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils
 
 object Util {
 
-  def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
+  def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = {
     val tmpFile = new File(filePath)
-    var retry = maxRetry.getOrElse(3)
+    var retry = maxRetry
     var success = false
     if (!tmpFile.exists()) {
       while (retry > 0 && !success) {
diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java
new file mode 100644
index 0000000..30bc8db
--- /dev/null
+++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.predictor;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.apache.mxnetexamples.Util;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+public class PredictorExampleTest {
+
+    final static Logger logger = LoggerFactory.getLogger(PredictorExampleTest.class);
+    private static String modelPathPrefix = "";
+    private static String inputImagePath = "";
+
+    @BeforeClass
+    public static void downloadFile() {
+        logger.info("Downloading resnet-18 model");
+
+        String tempDirPath = System.getProperty("java.io.tmpdir");
+        logger.info("tempDirPath: {}", tempDirPath);
+
+        String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
+
+        Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
+                tempDirPath + "/resnet18/resnet-18-symbol.json", 3);
+        Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
+                tempDirPath + "/resnet18/resnet-18-0000.params", 3);
+        Util.downloadUrl(baseUrl + "/resnet-18/synset.txt",
+                tempDirPath + "/resnet18/synset.txt", 3);
+        Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+                tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg", 3);
+
+        modelPathPrefix = tempDirPath + File.separator + "resnet18/resnet-18";
+        inputImagePath = tempDirPath + File.separator +
+                "inputImages/resnet18/Pug-Cookie.jpg";
+    }
+
+    @Test
+    public void testPredictor() {
+        // main is the example's static CLI entry point; run it end to end
+        String[] args = new String[]{
+                "--model-path-prefix", modelPathPrefix,
+                "--input-image", inputImagePath
+        };
+        PredictorExample.main(args);
+    }
+
+}
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index 13ceebb..565ac6e 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -64,13 +64,5 @@
       <version>1.10.19</version>
       <scope>test</scope>
     </dependency>
-
-    <dependency>
-      <groupId>junit</groupId>
-      <artifactId>junit</artifactId>
-      <version>4.11</version>
-      <scope>test</scope>
-    </dependency>
-
   </dependencies>
 </project>
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index 4dfd6eb..fa3565b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -96,9 +96,9 @@ private[mxnet] object JavaNDArrayMacro extends GeneratorBase {
       // add default out parameter
       argDef += s"out: org.apache.mxnet.javaapi.NDArray"
       if (useParamObject) {
-        impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
+        impl += "if (po.getOut() != null) map(\"out\") = po.getOut().nd"
       } else {
-        impl += "if (out != null) map(\"out\") = out"
+        impl += "if (out != null) map(\"out\") = out.nd"
       }
       val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
       // scalastyle:off
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 07baeab..5ba6f1f 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -412,6 +412,12 @@
       <version>1.13.5</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.11</version>
+      <scope>test</scope>
+    </dependency>
 
     <!-- Following libraries are required by running javah, they should be excluded from .jar -->
     <dependency>