You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/10/12 22:15:51 UTC

[incubator-mxnet] branch master updated: Make Gluon download function to be atomic (#12572)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3eff8e8  Make Gluon download function to be atomic (#12572)
3eff8e8 is described below

commit 3eff8e8918bb1453a19793d803029537e22a1d3c
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Oct 12 15:15:39 2018 -0700

    Make Gluon download function to be atomic (#12572)
    
    * use the rename trick to achieve atomic writes (not yet supporting Python 2 or Windows)
    
    * add test for multiprocess download
    
    * implement atomic_replace based on https://github.com/untitaker/python-atomicwrites
    
    * change the number of testing process to 10
    
    * add docstring and disable linter
    
    * partially address issues raised by reviewers
    
    * use warning instead of raise UserWarn
    
    * check for sha1
    
    * Trigger CI
    
    * fix the logic of checking hash
    
    * refine the error message
    
    * add more comments and expose the error message to the user
    
    * delete trailing whitespace
    
    * rename _path_to_encode to _str_to_unicode
    
    * fix the error message bug and remove the temp file when MoveFileExW fails on Windows
    
    * add remove temp file for non-windows os
    
    * handle the OSError caused by os.remove
    
    * Trigger CI
    
    * use finally to raise failure of atomic replace
    
    * add missing try except block for os.remove
    
    * add retries value to error message
---
 python/mxnet/gluon/utils.py               | 98 +++++++++++++++++++++++++++----
 tests/python/unittest/test_gluon_utils.py | 46 +++++++++++++--
 2 files changed, 127 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index d5a14a6..7832498 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -22,7 +22,9 @@ __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
            'check_sha1', 'download']
 
 import os
+import sys
 import hashlib
+import uuid
 import warnings
 import collections
 import weakref
@@ -195,6 +197,62 @@ def check_sha1(filename, sha1_hash):
     return sha1.hexdigest() == sha1_hash
 
 
+if not sys.platform.startswith('win32'):
+    # POSIX implementation, adapted from https://github.com/untitaker/python-atomicwrites
+    def _replace_atomic(src, dst):
+        """Rename src to dst; on failure remove src and raise OSError. Internal use only"""
+        try:
+            os.rename(src, dst)
+        except OSError:
+            try:
+                os.remove(src)
+            except OSError:
+                pass
+            finally:  # always raise, whether or not src could be removed
+                raise OSError(  # NOTE(review): the line continuation below keeps indent spaces in the text
+                    'Moving downloaded temp file - {}, to {} failed. \
+                    Please retry the download.'.format(src, dst))
+else:
+    import ctypes
+
+    _MOVEFILE_REPLACE_EXISTING = 0x1
+    # Setting this value guarantees that a move performed as a copy
+    # and delete operation is flushed to disk before the function returns.
+    # The flush occurs at the end of the copy operation.
+    _MOVEFILE_WRITE_THROUGH = 0x8
+    _windows_default_flags = _MOVEFILE_WRITE_THROUGH
+
+    text_type = unicode if sys.version_info[0] == 2 else str  # noqa
+
+    def _str_to_unicode(x):
+        """Return x unchanged if it is text, else decode it with the filesystem encoding. Internal use only"""
+        if not isinstance(x, text_type):
+            return x.decode(sys.getfilesystemencoding())
+        return x
+
+    def _handle_errors(rv, src):
+        """If rv is falsy, remove src and raise OSError with the last WinError text. Internal use only"""
+        if not rv:
+            msg = ctypes.FormatError(ctypes.GetLastError())
+            # MoveFileExW failed (e.g. a file lock was not acquired): remove the temp file src
+            try:
+                os.remove(src)
+            except OSError:
+                pass
+            finally:  # always raise, even if removing src failed
+                raise OSError(msg)
+
+    def _replace_atomic(src, dst):
+        """Replace dst with src on Windows via MoveFileExW (REPLACE_EXISTING | WRITE_THROUGH).
+        refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
+        The call fails if any step of the move (copy, flush, delete) fails.
+        Raises OSError on failure. Internal use only"""
+        _handle_errors(ctypes.windll.kernel32.MoveFileExW(
+            _str_to_unicode(src), _str_to_unicode(dst),
+            _windows_default_flags | _MOVEFILE_REPLACE_EXISTING
+        ), src)
+
+
 def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
     """Download an given URL

@@ -231,7 +289,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
             fname = os.path.join(path, url.split('/')[-1])
         else:
             fname = path
-    assert retries >= 0, "Number of retries should be at least 0"
+    assert retries >= 0, "Number of retries should be at least 0, currently it's {}".format(
+        retries)

     if not verify_ssl:
         warnings.warn(
@@ -242,31 +301,48 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
             os.makedirs(dirname)
-        while retries+1 > 0:
+        while retries + 1 > 0:
             # Disable pyling too broad Exception
             # pylint: disable=W0703
             try:
-                print('Downloading %s from %s...'%(fname, url))
+                print('Downloading {} from {}...'.format(fname, url))
                 r = requests.get(url, stream=True, verify=verify_ssl)
                 if r.status_code != 200:
-                    raise RuntimeError("Failed downloading url %s"%url)
-                with open(fname, 'wb') as f:
+                    raise RuntimeError('Failed downloading url {}'.format(url))
+                # unique per-call suffix so concurrent downloads never share a temp file
+                random_uuid = str(uuid.uuid4())
+                with open('{}.{}'.format(fname, random_uuid), 'wb') as f:  # NOTE(review): not removed if the download raises
                     for chunk in r.iter_content(chunk_size=1024):
                         if chunk: # filter out keep-alive new chunks
                             f.write(chunk)
+                # Publish the temp file only if fname is missing or fails the sha1 check;
+                # otherwise keep the existing fname (e.g. written by another process) and
+                # delete our copy. NOTE(review): without sha1_hash this ignores overwrite=True?
+                if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
+                    # atomic when the temp file and fname share a filesystem (same directory here)
+                    _replace_atomic('{}.{}'.format(fname, random_uuid), fname)
+                else:
+                    try:
+                        os.remove('{}.{}'.format(fname, random_uuid))
+                    except OSError:
+                        pass
+                    finally:
+                        warnings.warn(
+                            'File {} exists in file system so the downloaded file is deleted'.format(fname))
                 if sha1_hash and not check_sha1(fname, sha1_hash):
-                    raise UserWarning('File {} is downloaded but the content hash does not match.'\
-                                      ' The repo may be outdated or download may be incomplete. '\
-                                      'If the "repo_url" is overridden, consider switching to '\
-                                      'the default repo.'.format(fname))
+                    raise UserWarning(
+                        'File {} is downloaded but the content hash does not match.'
+                        ' The repo may be outdated or download may be incomplete. '
+                        'If the "repo_url" is overridden, consider switching to '
+                        'the default repo.'.format(fname))
                 break
             except Exception as e:
                 retries -= 1
                 if retries <= 0:
                     raise e
                 else:
-                    print("download failed, retrying, {} attempt{} left"
-                          .format(retries, 's' if retries > 1 else ''))
+                    print('download failed due to {}, retrying, {} attempt{} left'
+                          .format(repr(e), retries, 's' if retries > 1 else ''))

     return fname
 
diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py
index 4318524..20f1c8c 100644
--- a/tests/python/unittest/test_gluon_utils.py
+++ b/tests/python/unittest/test_gluon_utils.py
@@ -19,6 +19,9 @@ import io
 import os
 import tempfile
 import warnings
+import glob
+import shutil
+import multiprocessing as mp
 
 try:
     from unittest import mock
@@ -46,15 +49,45 @@ def test_download_retries():
 
 @mock.patch(
     'requests.get',
-    mock.Mock(side_effect=
-              lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+    mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+def _download_successful(tmp):
+    """Call download() with requests.get mocked; ``tmp`` is the destination path."""
+    mx.gluon.utils.download(
+        "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
+        path=tmp)
+
+
 def test_download_successful():
+    """Single-process download writes README.md and leaves no temp files behind."""
     tmp = tempfile.mkdtemp()
     tmpfile = os.path.join(tmp, 'README.md')
-    mx.gluon.utils.download(
-        "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
-        path=tmpfile)
-    assert os.path.getsize(tmpfile) > 100
+    _download_successful(tmpfile)
+    assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+    pattern = os.path.join(tmp, 'README.md*')
+    # exactly one match: README.md itself, no leftover README.md.<uuid> temp file
+    assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+    # delete temp dir
+    shutil.rmtree(tmp)  # NOTE(review): skipped if an assert above fails
+
+
+def test_multiprocessing_download_successful():
+    """Ten concurrent downloads to one path leave a single README.md and no temp files."""
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'README.md')
+    process_list = []
+    # start 10 processes racing to download to the same path
+    for i in range(10):
+        process_list.append(mp.Process(
+            target=_download_successful, args=(tmpfile,)))
+        process_list[i].start()
+    for i in range(10):
+        process_list[i].join()  # NOTE(review): child exitcode is not checked
+    assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+    # exactly one match: no leftover README.md.<uuid> temp files
+    pattern = os.path.join(tmp, 'README.md*')
+    assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+    # delete temp dir
+    shutil.rmtree(tmp)
 
 
 @mock.patch(
@@ -62,6 +95,7 @@ def test_download_successful():
     mock.Mock(
         side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
 def test_download_ssl_verify():
+    """ test download verify_ssl parameter """
     with warnings.catch_warnings(record=True) as warnings_:
         mx.gluon.utils.download(
             "https://mxnet.incubator.apache.org/index.html", verify_ssl=False)