You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/06/08 01:13:04 UTC

[incubator-mxnet] branch master updated: [MXNET-525] Add retry logic to download functions to fix flaky tests (#11181)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b434b8e  [MXNET-525] Add retry logic to download functions to fix flaky tests (#11181)
b434b8e is described below

commit b434b8ec18f774c99b0830bd3ca66859212b4911
Author: Thomas Delteil <th...@gmail.com>
AuthorDate: Thu Jun 7 18:12:58 2018 -0700

    [MXNET-525] Add retry logic to download functions to fix flaky tests (#11181)
    
    * Adding retry logic to the download function
    
    * Adding retry logic on download
    
    * Fixing retry wording
    
    * addressing feedback
    
    * forgot parenthesis
---
 python/mxnet/gluon/utils.py               | 45 ++++++++++++++++++++-----------
 python/mxnet/test_utils.py                | 31 ++++++++++++++++-----
 tests/python/unittest/test_gluon_utils.py | 34 +++++++++++++++++++++++
 tests/python/unittest/test_test_utils.py  | 34 +++++++++++++++++++++++
 4 files changed, 121 insertions(+), 23 deletions(-)

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 818aa3d..06b91fa 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -171,7 +171,7 @@ def check_sha1(filename, sha1_hash):
     return sha1.hexdigest() == sha1_hash
 
 
-def download(url, path=None, overwrite=False, sha1_hash=None):
+def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
     """Download an given URL
 
     Parameters
@@ -186,6 +186,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
     sha1_hash : str, optional
         Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
         but doesn't match.
+    retries : integer, default 5
+        The number of times to attempt the download in case of failure or non 200 return codes
 
     Returns
     -------
@@ -200,26 +202,37 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
             fname = os.path.join(path, url.split('/')[-1])
         else:
             fname = path
+    assert retries >= 0, "Number of retries should be at least 0"
 
     if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
         dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
         if not os.path.exists(dirname):
             os.makedirs(dirname)
-
-        print('Downloading %s from %s...'%(fname, url))
-        r = requests.get(url, stream=True)
-        if r.status_code != 200:
-            raise RuntimeError("Failed downloading url %s"%url)
-        with open(fname, 'wb') as f:
-            for chunk in r.iter_content(chunk_size=1024):
-                if chunk: # filter out keep-alive new chunks
-                    f.write(chunk)
-
-        if sha1_hash and not check_sha1(fname, sha1_hash):
-            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
-                              'The repo may be outdated or download may be incomplete. ' \
-                              'If the "repo_url" is overridden, consider switching to ' \
-                              'the default repo.'.format(fname))
+        while retries+1 > 0:
+            # Disable pylint's too-broad-exception warning
+            # pylint: disable=W0703
+            try:
+                print('Downloading %s from %s...'%(fname, url))
+                r = requests.get(url, stream=True)
+                if r.status_code != 200:
+                    raise RuntimeError("Failed downloading url %s"%url)
+                with open(fname, 'wb') as f:
+                    for chunk in r.iter_content(chunk_size=1024):
+                        if chunk: # filter out keep-alive new chunks
+                            f.write(chunk)
+                if sha1_hash and not check_sha1(fname, sha1_hash):
+                    raise UserWarning('File {} is downloaded but the content hash does not match.'\
+                                      ' The repo may be outdated or download may be incomplete. '\
+                                      'If the "repo_url" is overridden, consider switching to '\
+                                      'the default repo.'.format(fname))
+                break
+            except Exception as e:
+                retries -= 1
+                if retries <= 0:
+                    raise e
+                else:
+                    print("download failed, retrying, {} attempt{} left"
+                          .format(retries, 's' if retries > 1 else ''))
 
     return fname
 
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index bcdcc9c..686802d 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1367,7 +1367,7 @@ def list_gpus():
             pass
     return range(len([i for i in re.split('\n') if 'GPU' in i]))
 
-def download(url, fname=None, dirname=None, overwrite=False):
+def download(url, fname=None, dirname=None, overwrite=False, retries=5):
     """Download an given URL
 
     Parameters
@@ -1385,12 +1385,17 @@ def download(url, fname=None, dirname=None, overwrite=False):
         Default is false, which means skipping download if the local file
         exists. If true, then download the url to overwrite the local file if
         exists.
+    retries : integer, default 5
+        The number of times to attempt the download in case of failure or non 200 return codes
 
     Returns
     -------
     str
         The filename of the downloaded file
     """
+
+    assert retries >= 0, "Number of retries should be at least 0"
+
     if fname is None:
         fname = url.split('/')[-1]
 
@@ -1411,12 +1416,24 @@ def download(url, fname=None, dirname=None, overwrite=False):
         logging.info("%s exists, skipping download", fname)
         return fname
 
-    r = requests.get(url, stream=True)
-    assert r.status_code == 200, "failed to open %s" % url
-    with open(fname, 'wb') as f:
-        for chunk in r.iter_content(chunk_size=1024):
-            if chunk: # filter out keep-alive new chunks
-                f.write(chunk)
+    while retries+1 > 0:
+        # Disable pylint's too-broad-exception warning
+        # pylint: disable=W0703
+        try:
+            r = requests.get(url, stream=True)
+            assert r.status_code == 200, "failed to open %s" % url
+            with open(fname, 'wb') as f:
+                for chunk in r.iter_content(chunk_size=1024):
+                    if chunk: # filter out keep-alive new chunks
+                        f.write(chunk)
+                break
+        except Exception as e:
+            retries -= 1
+            if retries <= 0:
+                raise e
+            else:
+                print("download failed, retrying, {} attempt{} left"
+                      .format(retries, 's' if retries > 1 else ''))
     logging.info("downloaded %s into %s successfully", url, fname)
     return fname
 
diff --git a/tests/python/unittest/test_gluon_utils.py b/tests/python/unittest/test_gluon_utils.py
new file mode 100644
index 0000000..a5d3b14
--- /dev/null
+++ b/tests/python/unittest/test_gluon_utils.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from nose.tools import *
+
+
+@raises(Exception)
+def test_download_retries():
+    mx.gluon.utils.download("http://doesnotexist.notfound")
+
+def test_download_successful():
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'README.md')
+    mx.gluon.utils.download("https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
+                            path=tmpfile)
+    assert os.path.getsize(tmpfile) > 100
\ No newline at end of file
diff --git a/tests/python/unittest/test_test_utils.py b/tests/python/unittest/test_test_utils.py
new file mode 100644
index 0000000..49f0b93
--- /dev/null
+++ b/tests/python/unittest/test_test_utils.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from nose.tools import *
+
+
+@raises(Exception)
+def test_download_retries():
+    mx.test_utils.download("http://doesnotexist.notfound")
+
+def test_download_successful():
+    tmp = tempfile.mkdtemp()
+    tmpfile = os.path.join(tmp, 'README.md')
+    mx.test_utils.download("https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
+                           fname=tmpfile)
+    assert os.path.getsize(tmpfile) > 100
\ No newline at end of file

-- 
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.