You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/11/12 16:21:23 UTC
[incubator-mxnet] branch master updated: Refactor kvstore test (#13140)
This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d8d2d6e Refactor kvstore test (#13140)
d8d2d6e is described below
commit d8d2d6ef3d688a465e47f7170c2a11da804c2835
Author: Pedro Larroy <92...@users.noreply.github.com>
AuthorDate: Mon Nov 12 17:21:11 2018 +0100
Refactor kvstore test (#13140)
* Refactor kvstore test
* Fix pylint
* Fix problem with some OSX not handling the cast on imDecode (#13207)
* Fix num_gpus
---
python/mxnet/test_utils.py | 17 +++++++++++++
tests/python/gpu/test_device.py | 46 ++++++++++++------------------------
tests/python/gpu/test_kvstore_gpu.py | 18 +-------------
3 files changed, 33 insertions(+), 48 deletions(-)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 38a2733..7ac63c6 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1998,3 +1998,20 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
if compare_states:
compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
+
+class EnvManager(object):
+ """Environment variable setter and unsetter via with idiom"""
+ def __init__(self, key, val):
+ self._key = key
+ self._next_val = val
+ self._prev_val = None
+
+ def __enter__(self):
+ self._prev_val = os.environ.get(self._key)
+ os.environ[self._key] = self._next_val
+
+ def __exit__(self, ptype, value, trace):
+ if self._prev_val:
+ os.environ[self._key] = self._prev_val
+ else:
+ del os.environ[self._key]
diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py
index 66772dc..cd8145c 100644
--- a/tests/python/gpu/test_device.py
+++ b/tests/python/gpu/test_device.py
@@ -19,35 +19,23 @@ import mxnet as mx
import numpy as np
import unittest
import os
+import logging
+
+from mxnet.test_utils import EnvManager
shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
keys = [1,2,3,4,5,6,7]
-num_gpus = len(mx.test_utils.list_gpus())
+num_gpus = mx.context.num_gpus()
if num_gpus > 8 :
- print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
- print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
+ logging.warn("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
+ logging.warn("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
num_gpus = 8;
gpus = range(1, 1+num_gpus)
-class EnvManager:
- def __init__(self, key, val):
- self._key = key
- self._next_val = val
- self._prev_val = None
-
- def __enter__(self):
- try:
- self._prev_val = os.environ[self._key]
- except KeyError:
- self._prev_val = ''
- os.environ[self._key] = self._next_val
-
- def __exit__(self, ptype, value, trace):
- os.environ[self._key] = self._prev_val
-
+@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU")
def test_device_pushpull():
def check_dense_pushpull(kv_type):
for shape, key in zip(shapes, keys):
@@ -63,20 +51,16 @@ def test_device_pushpull():
for x in range(n_gpus):
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)
- envs1 = '1'
- key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
- envs2 = ['','1']
- key2 = 'MXNET_KVSTORE_USETREE'
- for i in range(2):
- for val2 in envs2:
- with EnvManager(key2, val2):
+ kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
+ kvstore_usetree_values = ['','1']
+ kvstore_usetree = 'MXNET_KVSTORE_USETREE'
+ for _ in range(2):
+ for x in kvstore_usetree_values:
+ with EnvManager(kvstore_usetree, x):
check_dense_pushpull('local')
check_dense_pushpull('device')
-
- os.environ[key1] = envs1
- os.environ[key1] = ''
-
- print ("Passed")
+ os.environ[kvstore_tree_array_bound] = '1'
+ del os.environ[kvstore_tree_array_bound]
if __name__ == '__main__':
test_device_pushpull()
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 4232a59..8ff8752 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -21,7 +21,7 @@ import os
import mxnet as mx
import numpy as np
import unittest
-from mxnet.test_utils import assert_almost_equal, default_context
+from mxnet.test_utils import assert_almost_equal, default_context, EnvManager
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import setup_module, with_seed, teardown
@@ -30,22 +30,6 @@ shape = (4, 4)
keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']
-class EnvManager:
- def __init__(self, key, val):
- self._key = key
- self._next_val = val
- self._prev_val = None
-
- def __enter__(self):
- try:
- self._prev_val = os.environ[self._key]
- except KeyError:
- self._prev_val = ''
- os.environ[self._key] = self._next_val
-
- def __exit__(self, ptype, value, trace):
- os.environ[self._key] = self._prev_val
-
def init_kv_with_str(stype='default', kv_type='local'):
"""init kv """
kv = mx.kv.create(kv_type)