You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2018/01/16 17:02:41 UTC

[incubator-mxnet] branch master updated: Fix crash when opening an image, fix exception safety. (#9370)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8054290  Fix crash when opening an image, fix exception safety. (#9370)
8054290 is described below

commit 80542906eab56f186cc395621ab72c497194cd4b
Author: Pedro Larroy <92...@users.noreply.github.com>
AuthorDate: Tue Jan 16 18:02:36 2018 +0100

    Fix crash when opening an image, fix exception safety. (#9370)
    
    * Fixes multiple problems in imread
    
    - Fix exception safe temporary buffer on imread.
    - Fix bad_alloc crash when running imdecode on non-existing file.
    
    * Improve test_image, tar file is not decompressed every time making test faster
    added test_imread_vs_imdecode
---
 src/io/image_io.cc                  |  21 +--
 tests/python/unittest/test_image.py | 284 +++++++++++++++++++-----------------
 2 files changed, 166 insertions(+), 139 deletions(-)

diff --git a/src/io/image_io.cc b/src/io/image_io.cc
index e26736f..f6183a1 100644
--- a/src/io/image_io.cc
+++ b/src/io/image_io.cc
@@ -35,6 +35,7 @@
 #include <nnvm/tuple.h>
 
 #include <fstream>
+#include <cstring>
 
 #include "../operator/elemwise_op_common.h"
 
@@ -218,29 +219,31 @@ void Imread(const nnvm::NodeAttrs& attrs,
   const auto& param = nnvm::get<ImreadParam>(attrs.parsed);
 
   std::ifstream file(param.filename, std::ios::binary | std::ios::ate);
+  // if file is not open we get bad alloc after tellg
+  CHECK(file.is_open()) << "Imread: '" << param.filename
+      << "' couldn't open file: " << strerror(errno);
   size_t fsize = file.tellg();
   file.seekg(0, std::ios::beg);
-  auto buff = new uint8_t[fsize];
-  file.read(reinterpret_cast<char*>(buff), fsize);
-  CHECK(file.good()) << "Failed reading image file " << param.filename;
+  std::shared_ptr<uint8_t> buff(new uint8_t[fsize], std::default_delete<uint8_t[]>());
+  file.read(reinterpret_cast<char*>(buff.get()), fsize);
+  CHECK(file.good()) << "Failed reading image file: '" << param.filename << "' "
+            << strerror(errno);
 
   TShape oshape(3);
   oshape[2] = param.flag == 0 ? 1 : 3;
-  if (get_jpeg_size(buff, fsize, &oshape[1], &oshape[0])) {
-  } else if (get_png_size(buff, fsize, &oshape[1], &oshape[0])) {
+  if (get_jpeg_size(buff.get(), fsize, &oshape[1], &oshape[0])) {
+  } else if (get_png_size(buff.get(), fsize, &oshape[1], &oshape[0])) {
   } else {
     (*outputs)[0] = NDArray();
-    ImdecodeImpl(param.flag, param.to_rgb, buff, fsize, &((*outputs)[0]));
-    delete[] buff;
+    ImdecodeImpl(param.flag, param.to_rgb, buff.get(), fsize, &((*outputs)[0]));
     return;
   }
 
   NDArray& ndout = (*outputs)[0];
   ndout = NDArray(oshape, Context::CPU(), true, mshadow::kUint8);
   Engine::Get()->PushSync([ndout, buff, fsize, param](RunContext ctx){
-      ImdecodeImpl(param.flag, param.to_rgb, buff, fsize,
+      ImdecodeImpl(param.flag, param.to_rgb, buff.get(), fsize,
                    const_cast<NDArray*>(&ndout));
-      delete[] buff;
     }, ndout.ctx(), {}, {ndout.var()},
     FnProperty::kNormal, 0, PROFILER_MESSAGE("Imread"));
 #else
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 04b878d..124c94c 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -19,7 +19,11 @@ import mxnet as mx
 import numpy as np
 from mxnet.test_utils import *
 from common import assertRaises
+import shutil
+import tempfile
+import unittest
 
+from nose.tools import raises
 
 def _get_data(url, dirname):
     import os, tarfile
@@ -33,100 +37,6 @@ def _get_data(url, dirname):
     tar.close()
     return source_images
 
-def _get_images():
-    return _get_data("http://data.mxnet.io/data/test_images.tar.gz", './data')
-
-def test_init():
-    _get_images()
-
-def test_imdecode():
-    try:
-        import cv2
-    except ImportError:
-        return
-    sources = _get_images()
-    for img in sources:
-        with open(img, 'rb') as fp:
-            str_image = fp.read()
-            image = mx.image.imdecode(str_image, to_rgb=0)
-        cv_image = cv2.imread(img)
-        assert_almost_equal(image.asnumpy(), cv_image)
-
-def test_scale_down():
-    assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106)
-    assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)
-    assert mx.image.scale_down((300, 400), (0, 0)) == (0, 0)
-
-def test_resize_short():
-    try:
-        import cv2
-    except ImportError:
-        return
-    sources = _get_images()
-    for img in sources:
-        cv_img = cv2.imread(img)
-        mx_img = mx.nd.array(cv_img[:, :, (2, 1, 0)])
-        h, w, _ = cv_img.shape
-        for _ in range(3):
-            new_size = np.random.randint(1, 1000)
-            if h > w:
-                new_h, new_w = new_size * h / w, new_size
-            else:
-                new_h, new_w = new_size, new_size * w / h
-            for interp in range(0, 2):
-                # area-based/lanczos don't match with cv2?
-                cv_resized = cv2.resize(cv_img, (new_w, new_h), interpolation=interp)
-                mx_resized = mx.image.resize_short(mx_img, new_size, interp)
-                assert_almost_equal(mx_resized.asnumpy()[:, :, (2, 1, 0)], cv_resized, atol=3)
-
-def test_color_normalize():
-    for _ in range(10):
-        mean = np.random.rand(3) * 255
-        std = np.random.rand(3) + 1
-        width = np.random.randint(100, 500)
-        height = np.random.randint(100, 500)
-        src = np.random.rand(height, width, 3) * 255.
-        mx_result = mx.image.color_normalize(mx.nd.array(src),
-            mx.nd.array(mean), mx.nd.array(std))
-        assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
-
-
-def test_imageiter():
-    sources = _get_images()
-    im_list = [[np.random.randint(0, 5), x] for x in sources]
-    test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
-        path_root='')
-    for _ in range(3):
-        for batch in test_iter:
-            pass
-        test_iter.reset()
-
-    # test with list file
-    fname = './data/test_imageiter.lst'
-    file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
-        for k, x in enumerate(sources)]
-    with open(fname, 'w') as f:
-        for line in file_list:
-            f.write(line + '\n')
-
-    test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
-        path_root='')
-    for batch in test_iter:
-        pass
-
-
-def test_augmenters():
-    # only test if all augmenters will work
-    # TODO(Joshua Zhang): verify the augmenter outputs
-    sources = _get_images()
-    im_list = [[0, x] for x in sources]
-    test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
-        resize=640, rand_crop=True, rand_resize=True, rand_mirror=True, mean=True,
-        std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
-        hue=0.1, pca_noise=0.1, rand_gray=0.2, inter_method=10, path_root='', shuffle=True)
-    for batch in test_iter:
-        pass
-
 def _generate_objects():
     num = np.random.randint(1, 10)
     xy = np.random.rand(num, 2)
@@ -140,44 +50,158 @@ def _generate_objects():
     label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
     return [2, 5] + label
 
-def test_image_detiter():
-    sources = _get_images()
-    im_list = [_generate_objects() + [x] for x in sources]
-    det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
-    for _ in range(3):
+
+class TestImage(unittest.TestCase):
+    IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
+    IMAGES = []
+    IMAGES_DIR = None
+
+    @classmethod
+    def setupClass(cls):
+        cls.IMAGES_DIR = tempfile.mkdtemp()
+        cls.IMAGES = _get_data(cls.IMAGES_URL, cls.IMAGES_DIR)
+        print("Loaded {} images".format(len(cls.IMAGES)))
+
+    @classmethod
+    def teardownClass(cls):
+        if cls.IMAGES_DIR:
+            print("cleanup {}".format(cls.IMAGES_DIR))
+            shutil.rmtree(cls.IMAGES_DIR)
+
+    @raises(mx.base.MXNetError)
+    def test_imread_not_found(self):
+        x = mx.img.image.imread("/139810923jadjsajlskd.___adskj/blah.jpg")
+
+    def test_imread_vs_imdecode(self):
+        for img in TestImage.IMAGES:
+            with open(img, 'rb') as fp:
+                str_image = fp.read()
+                image = mx.image.imdecode(str_image, to_rgb=0)
+                image_read = mx.img.image.imread(img)
+                same(image.asnumpy(), image_read.asnumpy())
+
+
+    def test_imdecode(self):
+        try:
+            import cv2
+        except ImportError:
+            return
+        for img in TestImage.IMAGES:
+            with open(img, 'rb') as fp:
+                str_image = fp.read()
+                image = mx.image.imdecode(str_image, to_rgb=0)
+            cv_image = cv2.imread(img)
+            assert_almost_equal(image.asnumpy(), cv_image)
+
+    def test_scale_down(self):
+        assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106)
+        assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)
+        assert mx.image.scale_down((300, 400), (0, 0)) == (0, 0)
+
+    def test_resize_short(self):
+        try:
+            import cv2
+        except ImportError:
+            return
+        for img in TestImage.IMAGES:
+            cv_img = cv2.imread(img)
+            mx_img = mx.nd.array(cv_img[:, :, (2, 1, 0)])
+            h, w, _ = cv_img.shape
+            for _ in range(3):
+                new_size = np.random.randint(1, 1000)
+                if h > w:
+                    new_h, new_w = new_size * h / w, new_size
+                else:
+                    new_h, new_w = new_size, new_size * w / h
+                for interp in range(0, 2):
+                    # area-based/lanczos don't match with cv2?
+                    cv_resized = cv2.resize(cv_img, (new_w, new_h), interpolation=interp)
+                    mx_resized = mx.image.resize_short(mx_img, new_size, interp)
+                    assert_almost_equal(mx_resized.asnumpy()[:, :, (2, 1, 0)], cv_resized, atol=3)
+
+    def test_color_normalize(self):
+        for _ in range(10):
+            mean = np.random.rand(3) * 255
+            std = np.random.rand(3) + 1
+            width = np.random.randint(100, 500)
+            height = np.random.randint(100, 500)
+            src = np.random.rand(height, width, 3) * 255.
+            mx_result = mx.image.color_normalize(mx.nd.array(src),
+                mx.nd.array(mean), mx.nd.array(std))
+            assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)
+
+
+    def test_imageiter(self):
+        im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
+        test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
+            path_root='')
+        for _ in range(3):
+            for batch in test_iter:
+                pass
+            test_iter.reset()
+
+        # test with list file
+        fname = './data/test_imageiter.lst'
+        file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
+            for k, x in enumerate(TestImage.IMAGES)]
+        with open(fname, 'w') as f:
+            for line in file_list:
+                f.write(line + '\n')
+
+        test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
+            path_root='')
+        for batch in test_iter:
+            pass
+
+
+    def test_augmenters(self):
+        # only test if all augmenters will work
+        # TODO(Joshua Zhang): verify the augmenter outputs
+        im_list = [[0, x] for x in TestImage.IMAGES]
+        test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
+            resize=640, rand_crop=True, rand_resize=True, rand_mirror=True, mean=True,
+            std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
+            hue=0.1, pca_noise=0.1, rand_gray=0.2, inter_method=10, path_root='', shuffle=True)
+        for batch in test_iter:
+            pass
+
+
+    def test_image_detiter(self):
+        im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
+        det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
+        for _ in range(3):
+            for batch in det_iter:
+                pass
+            det_iter.reset()
+
+        val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
+        det_iter = val_iter.sync_label_shape(det_iter)
+
+        # test file list
+        fname = './data/test_imagedetiter.lst'
+        im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)]
+        with open(fname, 'w') as f:
+            for line in im_list:
+                line = '\t'.join([str(k) for k in line])
+                f.write(line + '\n')
+
+        det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
+            path_root='')
+        for batch in det_iter:
+            pass
+
+    def test_det_augmenters(self):
+        # only test if all augmenters will work
+        # TODO(Joshua Zhang): verify the augmenter outputs
+        im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
+        det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='',
+            resize=640, rand_crop=1, rand_pad=1, rand_gray=0.1, rand_mirror=True, mean=True,
+            std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
+            pca_noise=0.1, hue=0.1, inter_method=10, min_object_covered=0.5,
+            aspect_ratio_range=(0.2, 5), area_range=(0.1, 4.0), min_eject_coverage=0.5,
+            max_attempts=50)
         for batch in det_iter:
             pass
-        det_iter.reset()
-
-    val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
-    det_iter = val_iter.sync_label_shape(det_iter)
-
-    # test file list
-    fname = './data/test_imagedetiter.lst'
-    im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(sources)]
-    with open(fname, 'w') as f:
-        for line in im_list:
-            line = '\t'.join([str(k) for k in line])
-            f.write(line + '\n')
-
-    det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
-        path_root='')
-    for batch in det_iter:
-        pass
-
-def test_det_augmenters():
-    # only test if all augmenters will work
-    # TODO(Joshua Zhang): verify the augmenter outputs
-    sources = _get_images()
-    im_list = [_generate_objects() + [x] for x in sources]
-    det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='',
-        resize=640, rand_crop=1, rand_pad=1, rand_gray=0.1, rand_mirror=True, mean=True,
-        std=np.array([1.1, 1.03, 1.05]), brightness=0.1, contrast=0.1, saturation=0.1,
-        pca_noise=0.1, hue=0.1, inter_method=10, min_object_covered=0.5,
-        aspect_ratio_range=(0.2, 5), area_range=(0.1, 4.0), min_eject_coverage=0.5,
-        max_attempts=50)
-    for batch in det_iter:
-        pass
 
 if __name__ == '__main__':
     import nose

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.