You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/10 17:20:42 UTC

[incubator-mxnet] branch master updated: add verification to gluon dataset (#7322)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d1407f  add verification to gluon dataset (#7322)
0d1407f is described below

commit 0d1407fb69c1f3a71ef6c8d717d97d5aa0a44061
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu Aug 10 10:20:39 2017 -0700

    add verification to gluon dataset (#7322)
    
    * add verification to gluon dataset
    
    * fix
    
    * rename variables
    
    * add tests
    
    * fix doc
---
 docs/api/python/gluon.md                        |  4 +--
 python/mxnet/gluon/data/vision.py               | 40 ++++++++++++++++++-------
 python/mxnet/gluon/model_zoo/model_store.py     | 25 +++++-----------
 python/mxnet/gluon/model_zoo/vision/__init__.py | 13 ++++++--
 python/mxnet/gluon/utils.py                     | 36 ++++++++++++++++++++--
 tests/python/unittest/test_gluon_data.py        |  4 +++
 6 files changed, 85 insertions(+), 37 deletions(-)

diff --git a/docs/api/python/gluon.md b/docs/api/python/gluon.md
index 6e213bb..ac63774 100644
--- a/docs/api/python/gluon.md
+++ b/docs/api/python/gluon.md
@@ -239,6 +239,7 @@ Model zoo provides pre-defined and pre-trained models to help bootstrap machine
 
 ```eval_rst
 .. currentmodule:: mxnet.gluon.model_zoo.vision
+.. automodule:: mxnet.gluon.model_zoo.vision
 ```
 
 ```eval_rst
@@ -508,8 +509,7 @@ Model zoo provides pre-defined and pre-trained models to help bootstrap machine
 .. automodule:: mxnet.gluon.data.vision
     :members:
 
-.. automodule:: mxnet.gluon.model_zoo.vision
-    :members:
+.. automethod:: mxnet.gluon.model_zoo.vision.get_model
 .. automethod:: mxnet.gluon.model_zoo.vision.resnet18_v1
 .. automethod:: mxnet.gluon.model_zoo.vision.resnet34_v1
 .. automethod:: mxnet.gluon.model_zoo.vision.resnet50_v1
diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision.py
index 4ddbbbd..a16e736 100644
--- a/python/mxnet/gluon/data/vision.py
+++ b/python/mxnet/gluon/data/vision.py
@@ -26,7 +26,7 @@ import struct
 import numpy as np
 
 from . import dataset
-from ..utils import download
+from ..utils import download, check_sha1
 from ... import nd
 
 
@@ -67,7 +67,8 @@ class MNIST(_DownloadedDataset):
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
     """
-    def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+    def __init__(self, root='~/.mxnet/datasets/', train=True,
+                 transform=lambda data, label: (data, label)):
         super(MNIST, self).__init__(root, train, transform)
 
     def _get_data(self):
@@ -75,11 +76,15 @@ class MNIST(_DownloadedDataset):
             os.makedirs(self._root)
         url = 'http://data.mxnet.io/data/mnist/'
         if self._train:
-            data_file = download(url+'train-images-idx3-ubyte.gz', self._root)
-            label_file = download(url+'train-labels-idx1-ubyte.gz', self._root)
+            data_file = download(url+'train-images-idx3-ubyte.gz', self._root,
+                                 sha1_hash='6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
+            label_file = download(url+'train-labels-idx1-ubyte.gz', self._root,
+                                  sha1_hash='2a80914081dc54586dbdf242f9805a6b8d2a15fc')
         else:
-            data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root)
-            label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root)
+            data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root,
+                                 sha1_hash='c3a25af1f52dad7f726cce8cacb138654b760d48')
+            label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root,
+                                  sha1_hash='763e7fa3757d93b0cdec073cef058b2004252c17')
 
         with gzip.open(label_file, 'rb') as fin:
             struct.unpack(">II", fin.read(8))
@@ -110,7 +115,14 @@ class CIFAR10(_DownloadedDataset):
 
             transform=lambda data, label: (data.astype(np.float32)/255, label)
     """
-    def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+    def __init__(self, root='~/.mxnet/datasets/', train=True,
+                 transform=lambda data, label: (data, label)):
+        self._file_hashes = {'data_batch_1.bin': 'aadd24acce27caa71bf4b10992e9e7b2d74c2540',
+                             'data_batch_2.bin': 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795',
+                             'data_batch_3.bin': '1dd00a74ab1d17a6e7d73e185b69dbf31242f295',
+                             'data_batch_4.bin': 'aab85764eb3584312d3c7f65fd2fd016e36a258e',
+                             'data_batch_5.bin': '26e2849e66a845b7f1e4614ae70f4889ae604628',
+                             'test_batch.bin': '67eb016db431130d61cd03c7ad570b013799c88c'}
         super(CIFAR10, self).__init__(root, train, transform)
 
     def _read_batch(self, filename):
@@ -123,11 +135,17 @@ class CIFAR10(_DownloadedDataset):
     def _get_data(self):
         if not os.path.isdir(self._root):
             os.makedirs(self._root)
-        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
-        filename = download(url, self._root)
 
-        with tarfile.open(filename) as tar:
-            tar.extractall(self._root)
+        file_paths = [(name, os.path.join(self._root, 'cifar-10-batches-bin/', name))
+                      for name in self._file_hashes]
+        if any(not os.path.exists(path) or not check_sha1(path, self._file_hashes[name])
+               for name, path in file_paths):
+            url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+            filename = download(url, self._root,
+                                sha1_hash='e8aa088b9774a44ad217101d2e2569f823d2d491')
+
+            with tarfile.open(filename) as tar:
+                tar.extractall(self._root)
 
         if self._train:
             filename = os.path.join(self._root, 'cifar-10-batches-bin/data_batch_%d.bin')
diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py
index e3c48ba..67ba572 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -19,11 +19,10 @@
 """Model zoo for pre-trained models."""
 from __future__ import print_function
 __all__ = ['get_model_file', 'purge']
-import hashlib
 import os
 import zipfile
 
-from ...test_utils import download
+from ..utils import download, check_sha1
 
 _model_sha1 = {name: checksum for checksum, name in [
     ('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
@@ -56,21 +55,11 @@ def short_hash(name):
         raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
     return _model_sha1[name][:8]
 
-def verified(file_path, name):
-    sha1 = hashlib.sha1()
-    with open(file_path, 'rb') as f:
-        while True:
-            data = f.read(1048576)
-            if not data:
-                break
-            sha1.update(data)
-
-    return sha1.hexdigest() == _model_sha1[name]
-
 def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
     r"""Return location for the pretrained on local file system.
 
     This function will download from online model zoo when model cannot be found or has mismatch.
+    The local_dir directory will be created if it doesn't exist.
 
     Parameters
     ----------
@@ -87,8 +76,9 @@ def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
     file_name = '{name}-{short_hash}'.format(name=name,
                                              short_hash=short_hash(name))
     file_path = os.path.join(local_dir, file_name+'.params')
+    sha1_hash = _model_sha1[name]
     if os.path.exists(file_path):
-        if verified(file_path, name):
+        if check_sha1(file_path, sha1_hash):
             return file_path
         else:
             print('Mismatch in the content of model file detected. Downloading again.')
@@ -98,17 +88,16 @@ def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
     if not os.path.exists(local_dir):
         os.makedirs(local_dir)
 
+    zip_file_path = os.path.join(local_dir, file_name+'.zip')
     download(_url_format.format(bucket=bucket,
                                 file_name=file_name),
-             fname=file_name+'.zip',
-             dirname=local_dir,
+             path=zip_file_path,
              overwrite=True)
-    zip_file_path = os.path.join(local_dir, file_name+'.zip')
     with zipfile.ZipFile(zip_file_path) as zf:
         zf.extractall(local_dir)
     os.remove(zip_file_path)
 
-    if verified(file_path, name):
+    if check_sha1(file_path, sha1_hash):
         return file_path
     else:
         raise ValueError('Downloaded file has different hash. Please try again.')
diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py b/python/mxnet/gluon/model_zoo/vision/__init__.py
index e4016db..354236b 100644
--- a/python/mxnet/gluon/model_zoo/vision/__init__.py
+++ b/python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 # pylint: disable=wildcard-import, arguments-differ
 r"""Module for pre-defined neural network models.
+
 This module contains definitions for the following model architectures:
 -  `AlexNet`_
 -  `DenseNet`_
@@ -26,21 +27,26 @@ This module contains definitions for the following model architectures:
 -  `ResNet V2`_
 -  `SqueezeNet`_
 -  `VGG`_
+
 You can construct a model with random weights by calling its constructor:
-.. code:: python
+.. code::
+
     import mxnet.gluon.models as models
     resnet18 = models.resnet18_v1()
     alexnet = models.alexnet()
     squeezenet = models.squeezenet1_0()
     densenet = models.densenet_161()
+
 We provide pre-trained models for all the models except ResNet V2.
 These can constructed by passing
 ``pretrained=True``:
-.. code:: python
+.. code::
+
     import mxnet.gluon.models as models
     resnet18 = models.resnet18_v1(pretrained=True)
     alexnet = models.alexnet(pretrained=True)
-Pretrained model is converted from torchvision.
+
+Pretrained models are converted from torchvision.
 All pre-trained models expect input images normalized in the same way,
 i.e. mini-batches of 3-channel RGB images of shape (N x 3 x H x W),
 where N is the batch size, and H and W are expected to be at least 224.
@@ -48,6 +54,7 @@ The images have to be loaded in to a range of [0, 1] and then normalized
 using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
 The transformation should preferrably happen at preprocessing. You can use
 ``mx.image.color_normalize`` for such transformation::
+
     image = image/255
     normalized = mx.image.color_normalize(image,
                                           mean=mx.nd.array([0.485, 0.456, 0.406]),
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 7d9c378..cece22b 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -19,6 +19,7 @@
 # pylint: disable=
 """Parallelization utility optimizer."""
 import os
+import hashlib
 try:
     import requests
 except ImportError:
@@ -136,7 +137,33 @@ def _indent(s_, numSpaces):
     return s
 
 
-def download(url, path=None, overwrite=False):
+def check_sha1(filename, sha1_hash):
+    """Check whether the sha1 hash of the file content matches the expected hash.
+
+    Parameters
+    ----------
+    filename : str
+        Path to the file.
+    sha1_hash : str
+        Expected sha1 hash in hexadecimal digits.
+
+    Returns
+    -------
+    bool
+        Whether the file content matches the expected hash.
+    """
+    sha1 = hashlib.sha1()
+    with open(filename, 'rb') as f:
+        while True:
+            data = f.read(1048576)
+            if not data:
+                break
+            sha1.update(data)
+
+    return sha1.hexdigest() == sha1_hash
+
+
+def download(url, path=None, overwrite=False, sha1_hash=None):
     """Download an given URL
 
     Parameters
@@ -148,11 +175,14 @@ def download(url, path=None, overwrite=False):
         current directory with same name as in url.
     overwrite : bool, optional
         Whether to overwrite destination file if already exists.
+    sha1_hash : str, optional
+        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
+        but doesn't match.
 
     Returns
     -------
     str
-        The filename of the downloaded file.
+        The file path of the downloaded file.
     """
     if path is None:
         fname = url.split('/')[-1]
@@ -161,7 +191,7 @@ def download(url, path=None, overwrite=False):
     else:
         fname = path
 
-    if overwrite or not os.path.exists(fname):
+    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
             os.makedirs(dirname)
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index da1de6b..e9a4301 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -67,6 +67,10 @@ def test_sampler():
     rand_batch_keep = gluon.data.BatchSampler(rand_sampler, 3, 'keep')
     assert sorted(sum(list(rand_batch_keep), [])) == list(range(10))
 
+def test_datasets():
+    assert len(gluon.data.vision.MNIST(root='data')) == 60000
+    assert len(gluon.data.vision.CIFAR10(root='data', train=False)) == 10000
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.