You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/14 00:46:14 UTC

[GitHub] stu1130 closed pull request #12466: [WIP] Fix race-conditon download function

stu1130 closed pull request #12466: [WIP] Fix race-conditon download function
URL: https://github.com/apache/incubator-mxnet/pull/12466
 
 
   

This pull request was submitted from a forked repository.
Because GitHub does not show the original diff of such a foreign pull
request once it has been merged or closed, the diff is reproduced below
for the sake of provenance:

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index f04479d2371..eed26c6d9b5 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -22,6 +22,7 @@
            'check_sha1', 'download']
 
 import os
+import uuid
 import hashlib
 import warnings
 import collections
@@ -222,7 +223,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
             os.makedirs(dirname)
-        while retries+1 > 0:
+        while retries + 1 > 0:
             # Disable pyling too broad Exception
             # pylint: disable=W0703
             try:
@@ -230,15 +231,25 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
                 r = requests.get(url, stream=True, verify=verify_ssl)
                 if r.status_code != 200:
                     raise RuntimeError("Failed downloading url %s"%url)
-                with open(fname, 'wb') as f:
+                # create uuid for temporary files
+                random_uuid = str(uuid.uuid4())
+                with open('{}.{}'.format(fname, random_uuid), 'wb') as f:
                     for chunk in r.iter_content(chunk_size=1024):
                         if chunk: # filter out keep-alive new chunks
                             f.write(chunk)
-                if sha1_hash and not check_sha1(fname, sha1_hash):
-                    raise UserWarning('File {} is downloaded but the content hash does not match.'\
-                                      ' The repo may be outdated or download may be incomplete. '\
-                                      'If the "repo_url" is overridden, consider switching to '\
-                                      'the default repo.'.format(fname))
+                    # if the target file exists(created by other processes),
+                    # delete the temporary file
+                    if os.path.exists(fname):
+                        os.remove('{}.{}'.format(fname, random_uuid))
+                    else:
+                        # atomic operation in the same file system
+                        os.replace('{}.{}'.format(fname, random_uuid), fname)
+                    if sha1_hash and not check_sha1(fname, sha1_hash):
+                        raise UserWarning(
+                            'File {} is downloaded but the content hash does not match.'
+                            ' The repo may be outdated or download may be incomplete. '
+                            'If the "repo_url" is overridden, consider switching to '
+                            'the default repo.'.format(fname))
                 break
             except Exception as e:
                 retries -= 1
diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py
index 431852427f5..5f084ef6854 100644
--- a/tests/python/unittest/test_gluon_utils.py
+++ b/tests/python/unittest/test_gluon_utils.py
@@ -19,15 +19,16 @@
 import os
 import tempfile
 import warnings
-
+import glob
+import shutil
+import multiprocessing as mp
 try:
     from unittest import mock
 except ImportError:
     import mock
-import mxnet as mx
 import requests
 from nose.tools import raises
-
+import mxnet as mx
 
 class MockResponse(requests.Response):
     def __init__(self, status_code, content):
@@ -46,15 +47,45 @@ def test_download_retries():
 
 @mock.patch(
     'requests.get',
-    mock.Mock(side_effect=
-              lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+    mock.Mock(side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
+def _download_successful(tmp):
+    """ internal use for testing download successfully """
+    mx.gluon.utils.download(
+        "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
+        path=tmp)
+
+
 def test_download_successful():
+    """ test download with one process """
     tmp = tempfile.mkdtemp()
     tmpfile = os.path.join(tmp, 'README.md')
-    mx.gluon.utils.download(
-        "https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
-        path=tmpfile)
-    assert os.path.getsize(tmpfile) > 100
+    _download_successful(tmpfile)
+    assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+    pattern = os.path.join(tmp, 'README.md*')
+    # check only one file we want left
+    assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+    # delete temp dir
+    shutil.rmtree(tmp)
+
+
+def test_multiprocessing_download_successful():
+    """ test download with multiprocessing """
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'README.md')
+    process_list = []
+    # test it with 10 processes
+    for i in range(10):
+        process_list.append(mp.Process(
+            target=_download_successful, args=(tmpfile,)))
+        process_list[i].start()
+    for i in range(10):
+        process_list[i].join()
+    assert os.path.getsize(tmpfile) > 100, os.path.getsize(tmpfile)
+    # check only one file we want left
+    pattern = os.path.join(tmp, 'README.md*')
+    assert len(glob.glob(pattern)) == 1, glob.glob(pattern)
+    # delete temp dir
+    shutil.rmtree(tmp)
 
 
 @mock.patch(
@@ -62,6 +93,7 @@ def test_download_successful():
     mock.Mock(
         side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
 def test_download_ssl_verify():
+    """ test download verify_ssl parameter """
     with warnings.catch_warnings(record=True) as warnings_:
         mx.gluon.utils.download(
             "https://mxnet.incubator.apache.org/index.html", verify_ssl=False)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services