You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/10/12 22:15:51 UTC
[incubator-mxnet] branch master updated: Make Gluon download
function to be atomic (#12572)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3eff8e8 Make Gluon download function to be atomic (#12572)
3eff8e8 is described below
commit 3eff8e8918bb1453a19793d803029537e22a1d3c
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Oct 12 15:15:39 2018 -0700
Make Gluon download function to be atomic (#12572)
* use rename trick to achieve atomic write but didn't support python2 and windows
* add test for multiprocess download
* implement atomic_replace referred by https://github.com/untitaker/python-atomicwrites
* change the number of testing process to 10
* add docstring and disable linter
* half way to address some issue reviewer have
* use warning instead of raise UserWarn
* check for sha1
* Trigger CI
* fix the logic of checking hash
* refine the error message
* add more comments and expose the error message to the user
* delete trailing whitespace
* rename _path_to_encode to _str_to_unicode
* fix the error message bug and add remove when the movefile fail on windows
* add remove temp file for non-windows os
* handle the OSError caused by os.remove
* Trigger CI
* use finally to raise failure of atomic replace
* add missing try except block for os.remove
* add retries value to error message
---
python/mxnet/gluon/utils.py | 98 +++++++++++++++++++++++++++----
tests/python/unittest/test_gluon_utils.py | 46 +++++++++++++--
2 files changed, 127 insertions(+), 17 deletions(-)
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index d5a14a6..7832498 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -22,7 +22,9 @@ __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
'check_sha1', 'download']
import os
+import sys
import hashlib
+import uuid
import warnings
import collections
import weakref
@@ -195,6 +197,62 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash
+if not sys.platform.startswith('win32'):
+ # refer to https://github.com/untitaker/python-atomicwrites
+ def _replace_atomic(src, dst):
+ """Implement atomic os.replace on Linux and OS X. Internal use only"""
+ try:
+ os.rename(src, dst)
+ except OSError:
+ try:
+ os.remove(src)
+ except OSError:
+ pass
+ finally:
+ raise OSError(
+ 'Moving downloaded temp file - {}, to {} failed. \
+ Please retry the download.'.format(src, dst))
+else:
+ import ctypes
+
+ _MOVEFILE_REPLACE_EXISTING = 0x1
+ # Setting this value guarantees that a move performed as a copy
+ # and delete operation is flushed to disk before the function returns.
+ # The flush occurs at the end of the copy operation.
+ _MOVEFILE_WRITE_THROUGH = 0x8
+ _windows_default_flags = _MOVEFILE_WRITE_THROUGH
+
+ text_type = unicode if sys.version_info[0] == 2 else str # noqa
+
+ def _str_to_unicode(x):
+ """Handle text decoding. Internal use only"""
+ if not isinstance(x, text_type):
+ return x.decode(sys.getfilesystemencoding())
+ return x
+
+ def _handle_errors(rv, src):
+ """Handle WinError. Internal use only"""
+ if not rv:
+ msg = ctypes.FormatError(ctypes.GetLastError())
+ # if MoveFileExW fails (e.g. failure to acquire a file lock), remove the temp file
+ try:
+ os.remove(src)
+ except OSError:
+ pass
+ finally:
+ raise OSError(msg)
+
+ def _replace_atomic(src, dst):
+ """Implement atomic os.replace on Windows.
+ refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
+ The function fails when one of the process(copy, flush, delete) fails.
+ Internal use only"""
+ _handle_errors(ctypes.windll.kernel32.MoveFileExW(
+ _str_to_unicode(src), _str_to_unicode(dst),
+ _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
+ ), src)
+
+
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL
@@ -231,7 +289,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
fname = os.path.join(path, url.split('/')[-1])
else:
fname = path
- assert retries >= 0, "Number of retries should be at least 0"
+ assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
+ retries)
if not verify_ssl:
warnings.warn(
@@ -242,31 +301,48 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
os.makedirs(dirname)
- while retries+1 > 0:
+ while retries + 1 > 0:
# Disable pylint too-broad-Exception warning
# pylint: disable=W0703
try:
- print('Downloading %s from %s...'%(fname, url))
+ print('Downloading {} from {}...'.format(fname, url))
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
- raise RuntimeError("Failed downloading url %s"%url)
- with open(fname, 'wb') as f:
+ raise RuntimeError('Failed downloading url {}'.format(url))
+ # create uuid for temporary files
+ random_uuid = str(uuid.uuid4())
+ with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
+ # if the target file already exists (created by another process)
+ # and has the same hash as the downloaded file,
+ # delete the temporary file
+ if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
+ # atomic operation within the same file system
+ _replace_atomic('{}.{}'.format(fname, random_uuid), fname)
+ else:
+ try:
+ os.remove('{}.{}'.format(fname, random_uuid))
+ except OSError:
+ pass
+ finally:
+ warnings.warn(
+ 'File {} exists in file system so the downloaded file is deleted'.format(fname))
if sha1_hash and not check_sha1(fname, sha1_hash):
- raise UserWarning('File {} is downloaded but the content hash does not match.'\
- ' The repo may be outdated or download may be incomplete. '\
- 'If the "repo_url" is overridden, consider switching to '\
- 'the default repo.'.format(fname))
+ raise UserWarning(
+ 'File {} is downloaded but the content hash does not match.'
+ ' The repo may be outdated or download may be incomplete. '
+ 'If the "repo_url" is overridden, consider switching to '
+ 'the default repo.'.format(fname))
break
except Exception as e:
retries -= 1
if retries <= 0:
raise e
else:
- print("download failed, retrying, {} attempt{} left"
- .format(retries, 's' if retries > 1 else ''))
+ print('download failed due to {}, retrying, {} attempt{} left'
+ .format(repr(e), retries, 's' if retries > 1 else ''))
return fname
diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py
index 4318524..20f1c8c 100644
--- a/tests/python/unittest/test_gluon_utils.py
+++ b/tests/python/unittest/test_gluon_utils.py
@@ -19,6 +19,9 @@ import io
import os
import tempfile
import warnings
+import glob
+import shutil
+import multiprocessing as mp
try:
from unittest import mock
@@ -46,15 +49,45 @@ def test_download_retries():
@mock.patch(
'requests.get',
- mock.Mock(side_effect=
- lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+ mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+def _download_successful(tmp):
+ """ internal use for testing download successfully """
+ mx.gluon.utils.download(
+ "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
+ path=tmp)
+
+
def test_download_successful():
+ """ test download with one process """
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
- mx.gluon.utils.download(
- "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
- path=tmpfile)
- assert os.path.getsize(tmpfile) > 100
+ _download_successful(tmpfile)
+ assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+ pattern = os.path.join(tmp, 'README.md*')
+ # check only one file we want left
+ assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+ # delete temp dir
+ shutil.rmtree(tmp)
+
+
+def test_multiprocessing_download_successful():
+ """ test download with multiprocessing """
+ tmp = tempfile.mkdtemp()
+ tmpfile = os.path.join(tmp, 'README.md')
+ process_list = []
+ # test it with 10 processes
+ for i in range(10):
+ process_list.append(mp.Process(
+ target=_download_successful, args=(tmpfile,)))
+ process_list[i].start()
+ for i in range(10):
+ process_list[i].join()
+ assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+ # check only one file we want left
+ pattern = os.path.join(tmp, 'README.md*')
+ assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+ # delete temp dir
+ shutil.rmtree(tmp)
@mock.patch(
@@ -62,6 +95,7 @@ def test_download_successful():
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
+ """ test download verify_ssl parameter """
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)