You are viewing a plain text version of this content. The canonical link for it (the hyperlink on "here") is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/11/12 16:21:23 UTC

[incubator-mxnet] branch master updated: Refactor kvstore test (#13140)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d8d2d6e  Refactor kvstore test (#13140)
d8d2d6e is described below

commit d8d2d6ef3d688a465e47f7170c2a11da804c2835
Author: Pedro Larroy <92...@users.noreply.github.com>
AuthorDate: Mon Nov 12 17:21:11 2018 +0100

    Refactor kvstore test (#13140)
    
    * Refactor kvstore test
    
    * Fix pylint
    
    * Fix problem with some OSX not handling the cast on imDecode (#13207)
    
    * Fix num_gpus
---
 python/mxnet/test_utils.py           | 17 +++++++++++++
 tests/python/gpu/test_device.py      | 46 ++++++++++++------------------------
 tests/python/gpu/test_kvstore_gpu.py | 18 +-------------
 3 files changed, 33 insertions(+), 48 deletions(-)

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 38a2733..7ac63c6 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1998,3 +1998,20 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
     if compare_states:
         compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
     assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
+
+class EnvManager(object):
+    """Environment variable setter and unsetter via with idiom"""
+    def __init__(self, key, val):
+        self._key = key
+        self._next_val = val
+        self._prev_val = None
+
+    def __enter__(self):
+        self._prev_val = os.environ.get(self._key)
+        os.environ[self._key] = self._next_val
+
+    def __exit__(self, ptype, value, trace):
+        if self._prev_val:
+            os.environ[self._key] = self._prev_val
+        else:
+            del os.environ[self._key]
diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py
index 66772dc..cd8145c 100644
--- a/tests/python/gpu/test_device.py
+++ b/tests/python/gpu/test_device.py
@@ -19,35 +19,23 @@ import mxnet as mx
 import numpy as np
 import unittest
 import os
+import logging
+
+from mxnet.test_utils import EnvManager
 
 shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
 keys = [1,2,3,4,5,6,7]
-num_gpus = len(mx.test_utils.list_gpus())
+num_gpus = mx.context.num_gpus()
 
 
 if num_gpus > 8 :
-    print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
-    print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
+    logging.warn("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
+    logging.warn("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
     num_gpus = 8;
 
 gpus = range(1, 1+num_gpus)
 
-class EnvManager:
-    def __init__(self, key, val):
-        self._key = key
-        self._next_val = val
-        self._prev_val = None
-
-    def __enter__(self):
-        try:
-            self._prev_val = os.environ[self._key]
-        except KeyError:
-            self._prev_val = ''
-        os.environ[self._key] = self._next_val
-
-    def __exit__(self, ptype, value, trace):
-        os.environ[self._key] = self._prev_val
-
+@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU")
 def test_device_pushpull():
     def check_dense_pushpull(kv_type):
         for shape, key in zip(shapes, keys):
@@ -63,20 +51,16 @@ def test_device_pushpull():
                 for x in range(n_gpus):
                     assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)
 
-    envs1 = '1'
-    key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
-    envs2 = ['','1']
-    key2  = 'MXNET_KVSTORE_USETREE'
-    for i in range(2):
-        for val2 in envs2:
-            with EnvManager(key2, val2):
+    kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
+    kvstore_usetree_values = ['','1']
+    kvstore_usetree  = 'MXNET_KVSTORE_USETREE'
+    for _ in range(2):
+        for x in kvstore_usetree_values:
+            with EnvManager(kvstore_usetree, x):
                 check_dense_pushpull('local')
                 check_dense_pushpull('device')
-
-        os.environ[key1] = envs1
-    os.environ[key1] = ''
-
-    print ("Passed")
+        os.environ[kvstore_tree_array_bound] = '1'
+    del os.environ[kvstore_tree_array_bound]
 
 if __name__ == '__main__':
     test_device_pushpull()
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 4232a59..8ff8752 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -21,7 +21,7 @@ import os
 import mxnet as mx
 import numpy as np
 import unittest
-from mxnet.test_utils import assert_almost_equal, default_context
+from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown
@@ -30,22 +30,6 @@ shape = (4, 4)
 keys = [5, 7, 11]
 str_keys = ['b', 'c', 'd']
 
-class EnvManager:
-    def __init__(self, key, val):
-        self._key = key
-        self._next_val = val
-        self._prev_val = None
-
-    def __enter__(self):
-        try:
-            self._prev_val = os.environ[self._key]
-        except KeyError:
-            self._prev_val = ''
-        os.environ[self._key] = self._next_val
-
-    def __exit__(self, ptype, value, trace):
-        os.environ[self._key] = self._prev_val
-
 def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
     kv = mx.kv.create(kv_type)