You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this version.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/10 17:20:42 UTC
[incubator-mxnet] branch master updated: add verification to gluon
dataset (#7322)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0d1407f add verification to gluon dataset (#7322)
0d1407f is described below
commit 0d1407fb69c1f3a71ef6c8d717d97d5aa0a44061
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Thu Aug 10 10:20:39 2017 -0700
add verification to gluon dataset (#7322)
* add verification to gluon dataset
* fix
* rename variables
* add tests
* fix doc
---
docs/api/python/gluon.md | 4 +--
python/mxnet/gluon/data/vision.py | 40 ++++++++++++++++++-------
python/mxnet/gluon/model_zoo/model_store.py | 25 +++++-----------
python/mxnet/gluon/model_zoo/vision/__init__.py | 13 ++++++--
python/mxnet/gluon/utils.py | 36 ++++++++++++++++++++--
tests/python/unittest/test_gluon_data.py | 4 +++
6 files changed, 85 insertions(+), 37 deletions(-)
diff --git a/docs/api/python/gluon.md b/docs/api/python/gluon.md
index 6e213bb..ac63774 100644
--- a/docs/api/python/gluon.md
+++ b/docs/api/python/gluon.md
@@ -239,6 +239,7 @@ Model zoo provides pre-defined and pre-trained models to help bootstrap machine
```eval_rst
.. currentmodule:: mxnet.gluon.model_zoo.vision
+.. automodule:: mxnet.gluon.model_zoo.vision
```
```eval_rst
@@ -508,8 +509,7 @@ Model zoo provides pre-defined and pre-trained models to help bootstrap machine
.. automodule:: mxnet.gluon.data.vision
:members:
-.. automodule:: mxnet.gluon.model_zoo.vision
- :members:
+.. automethod:: mxnet.gluon.model_zoo.vision.get_model
.. automethod:: mxnet.gluon.model_zoo.vision.resnet18_v1
.. automethod:: mxnet.gluon.model_zoo.vision.resnet34_v1
.. automethod:: mxnet.gluon.model_zoo.vision.resnet50_v1
diff --git a/python/mxnet/gluon/data/vision.py b/python/mxnet/gluon/data/vision.py
index 4ddbbbd..a16e736 100644
--- a/python/mxnet/gluon/data/vision.py
+++ b/python/mxnet/gluon/data/vision.py
@@ -26,7 +26,7 @@ import struct
import numpy as np
from . import dataset
-from ..utils import download
+from ..utils import download, check_sha1
from ... import nd
@@ -67,7 +67,8 @@ class MNIST(_DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+ def __init__(self, root='~/.mxnet/datasets/', train=True,
+ transform=lambda data, label: (data, label)):
super(MNIST, self).__init__(root, train, transform)
def _get_data(self):
@@ -75,11 +76,15 @@ class MNIST(_DownloadedDataset):
os.makedirs(self._root)
url = 'http://data.mxnet.io/data/mnist/'
if self._train:
- data_file = download(url+'train-images-idx3-ubyte.gz', self._root)
- label_file = download(url+'train-labels-idx1-ubyte.gz', self._root)
+ data_file = download(url+'train-images-idx3-ubyte.gz', self._root,
+ sha1_hash='6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d')
+ label_file = download(url+'train-labels-idx1-ubyte.gz', self._root,
+ sha1_hash='2a80914081dc54586dbdf242f9805a6b8d2a15fc')
else:
- data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root)
- label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root)
+ data_file = download(url+'t10k-images-idx3-ubyte.gz', self._root,
+ sha1_hash='c3a25af1f52dad7f726cce8cacb138654b760d48')
+ label_file = download(url+'t10k-labels-idx1-ubyte.gz', self._root,
+ sha1_hash='763e7fa3757d93b0cdec073cef058b2004252c17')
with gzip.open(label_file, 'rb') as fin:
struct.unpack(">II", fin.read(8))
@@ -110,7 +115,14 @@ class CIFAR10(_DownloadedDataset):
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
- def __init__(self, root, train=True, transform=lambda data, label: (data, label)):
+ def __init__(self, root='~/.mxnet/datasets/', train=True,
+ transform=lambda data, label: (data, label)):
+ self._file_hashes = {'data_batch_1.bin': 'aadd24acce27caa71bf4b10992e9e7b2d74c2540',
+ 'data_batch_2.bin': 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795',
+ 'data_batch_3.bin': '1dd00a74ab1d17a6e7d73e185b69dbf31242f295',
+ 'data_batch_4.bin': 'aab85764eb3584312d3c7f65fd2fd016e36a258e',
+ 'data_batch_5.bin': '26e2849e66a845b7f1e4614ae70f4889ae604628',
+ 'test_batch.bin': '67eb016db431130d61cd03c7ad570b013799c88c'}
super(CIFAR10, self).__init__(root, train, transform)
def _read_batch(self, filename):
@@ -123,11 +135,17 @@ class CIFAR10(_DownloadedDataset):
def _get_data(self):
if not os.path.isdir(self._root):
os.makedirs(self._root)
- url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
- filename = download(url, self._root)
- with tarfile.open(filename) as tar:
- tar.extractall(self._root)
+ file_paths = [(name, os.path.join(self._root, 'cifar-10-batches-bin/', name))
+ for name in self._file_hashes]
+ if any(not os.path.exists(path) or not check_sha1(path, self._file_hashes[name])
+ for name, path in file_paths):
+ url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+ filename = download(url, self._root,
+ sha1_hash='e8aa088b9774a44ad217101d2e2569f823d2d491')
+
+ with tarfile.open(filename) as tar:
+ tar.extractall(self._root)
if self._train:
filename = os.path.join(self._root, 'cifar-10-batches-bin/data_batch_%d.bin')
diff --git a/python/mxnet/gluon/model_zoo/model_store.py b/python/mxnet/gluon/model_zoo/model_store.py
index e3c48ba..67ba572 100644
--- a/python/mxnet/gluon/model_zoo/model_store.py
+++ b/python/mxnet/gluon/model_zoo/model_store.py
@@ -19,11 +19,10 @@
"""Model zoo for pre-trained models."""
from __future__ import print_function
__all__ = ['get_model_file', 'purge']
-import hashlib
import os
import zipfile
-from ...test_utils import download
+from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
('44335d1f0046b328243b32a26a4fbd62d9057b45', 'alexnet'),
@@ -56,21 +55,11 @@ def short_hash(name):
raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
return _model_sha1[name][:8]
-def verified(file_path, name):
- sha1 = hashlib.sha1()
- with open(file_path, 'rb') as f:
- while True:
- data = f.read(1048576)
- if not data:
- break
- sha1.update(data)
-
- return sha1.hexdigest() == _model_sha1[name]
-
def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
r"""Return location for the pretrained on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
+ The local_dir directory will be created if it doesn't exist.
Parameters
----------
@@ -87,8 +76,9 @@ def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
file_name = '{name}-{short_hash}'.format(name=name,
short_hash=short_hash(name))
file_path = os.path.join(local_dir, file_name+'.params')
+ sha1_hash = _model_sha1[name]
if os.path.exists(file_path):
- if verified(file_path, name):
+ if check_sha1(file_path, sha1_hash):
return file_path
else:
print('Mismatch in the content of model file detected. Downloading again.')
@@ -98,17 +88,16 @@ def get_model_file(name, local_dir=os.path.expanduser('~/.mxnet/models/')):
if not os.path.exists(local_dir):
os.makedirs(local_dir)
+ zip_file_path = os.path.join(local_dir, file_name+'.zip')
download(_url_format.format(bucket=bucket,
file_name=file_name),
- fname=file_name+'.zip',
- dirname=local_dir,
+ path=zip_file_path,
overwrite=True)
- zip_file_path = os.path.join(local_dir, file_name+'.zip')
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(local_dir)
os.remove(zip_file_path)
- if verified(file_path, name):
+ if check_sha1(file_path, sha1_hash):
return file_path
else:
raise ValueError('Downloaded file has different hash. Please try again.')
diff --git a/python/mxnet/gluon/model_zoo/vision/__init__.py b/python/mxnet/gluon/model_zoo/vision/__init__.py
index e4016db..354236b 100644
--- a/python/mxnet/gluon/model_zoo/vision/__init__.py
+++ b/python/mxnet/gluon/model_zoo/vision/__init__.py
@@ -18,6 +18,7 @@
# coding: utf-8
# pylint: disable=wildcard-import, arguments-differ
r"""Module for pre-defined neural network models.
+
This module contains definitions for the following model architectures:
- `AlexNet`_
- `DenseNet`_
@@ -26,21 +27,26 @@ This module contains definitions for the following model architectures:
- `ResNet V2`_
- `SqueezeNet`_
- `VGG`_
+
You can construct a model with random weights by calling its constructor:
-.. code:: python
+.. code::
+
import mxnet.gluon.models as models
resnet18 = models.resnet18_v1()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()
+
We provide pre-trained models for all the models except ResNet V2.
These can constructed by passing
``pretrained=True``:
-.. code:: python
+.. code::
+
import mxnet.gluon.models as models
resnet18 = models.resnet18_v1(pretrained=True)
alexnet = models.alexnet(pretrained=True)
-Pretrained model is converted from torchvision.
+
+Pretrained models are converted from torchvision.
All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (N x 3 x H x W),
where N is the batch size, and H and W are expected to be at least 224.
@@ -48,6 +54,7 @@ The images have to be loaded in to a range of [0, 1] and then normalized
using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
The transformation should preferrably happen at preprocessing. You can use
``mx.image.color_normalize`` for such transformation::
+
image = image/255
normalized = mx.image.color_normalize(image,
mean=mx.nd.array([0.485, 0.456, 0.406]),
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 7d9c378..cece22b 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -19,6 +19,7 @@
# pylint: disable=
"""Parallelization utility optimizer."""
import os
+import hashlib
try:
import requests
except ImportError:
@@ -136,7 +137,33 @@ def _indent(s_, numSpaces):
return s
-def download(url, path=None, overwrite=False):
+def check_sha1(filename, sha1_hash):
+ """Check whether the sha1 hash of the file content matches the expected hash.
+
+ Parameters
+ ----------
+ filename : str
+ Path to the file.
+ sha1_hash : str
+ Expected sha1 hash in hexadecimal digits.
+
+ Returns
+ -------
+ bool
+ Whether the file content matches the expected hash.
+ """
+ sha1 = hashlib.sha1()
+ with open(filename, 'rb') as f:
+ while True:
+ data = f.read(1048576)
+ if not data:
+ break
+ sha1.update(data)
+
+ return sha1.hexdigest() == sha1_hash
+
+
+def download(url, path=None, overwrite=False, sha1_hash=None):
"""Download an given URL
Parameters
@@ -148,11 +175,14 @@ def download(url, path=None, overwrite=False):
current directory with same name as in url.
overwrite : bool, optional
Whether to overwrite destination file if already exists.
+ sha1_hash : str, optional
+ Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
+ but doesn't match.
Returns
-------
str
- The filename of the downloaded file.
+ The file path of the downloaded file.
"""
if path is None:
fname = url.split('/')[-1]
@@ -161,7 +191,7 @@ def download(url, path=None, overwrite=False):
else:
fname = path
- if overwrite or not os.path.exists(fname):
+ if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index da1de6b..e9a4301 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -67,6 +67,10 @@ def test_sampler():
rand_batch_keep = gluon.data.BatchSampler(rand_sampler, 3, 'keep')
assert sorted(sum(list(rand_batch_keep), [])) == list(range(10))
+def test_datasets():
+ assert len(gluon.data.vision.MNIST(root='data')) == 60000
+ assert len(gluon.data.vision.CIFAR10(root='data', train=False)) == 10000
+
if __name__ == '__main__':
import nose
nose.runmodule()
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.