You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/12 16:21:13 UTC

[GitHub] marcoabreu closed pull request #13140: Refactor kvstore test

marcoabreu closed pull request #13140: Refactor kvstore test
URL: https://github.com/apache/incubator-mxnet/pull/13140
 
 
   

This is a PR merged from a forked repository (a foreign pull request).
Because GitHub does not show the original diff for such a merged pull
request, the diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 38a2733001f..7ac63c6c53d 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1998,3 +1998,36 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa
     if compare_states:
         compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
     assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol)
+
+class EnvManager(object):
+    """Temporarily set an environment variable inside a ``with`` block.
+
+    On exit the variable is restored to the value it had on entry, or
+    removed if it was unset on entry. An empty-string previous value is
+    restored as-is rather than deleted.
+
+    Parameters
+    ----------
+    key : str
+        Name of the environment variable.
+    val : str
+        Value to assign while the block runs.
+    """
+    def __init__(self, key, val):
+        self._key = key
+        self._next_val = val
+        self._prev_val = None
+
+    def __enter__(self):
+        self._prev_val = os.environ.get(self._key)
+        os.environ[self._key] = self._next_val
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        # Compare against None, not truthiness: '' is a legitimate prior
+        # value and must be restored, not deleted.
+        if self._prev_val is not None:
+            os.environ[self._key] = self._prev_val
+        else:
+            # pop() tolerates the with-body having already removed the key.
+            os.environ.pop(self._key, None)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
index 43f81a22a40..77881ab940b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -42,7 +42,8 @@ object Image {
   def imDecode(buf: Array[Byte], flag: Int,
                to_rgb: Boolean,
                out: Option[NDArray]): NDArray = {
-    val nd = NDArray.array(buf.map(_.toFloat), Shape(buf.length))
+    // JVM bytes are signed; masking with 0xFF reads each one as 0-255.
+    val nd = NDArray.array(buf.map(b => (b & 0xFF).toFloat), Shape(buf.length))
     val byteND = NDArray.api.cast(nd, "uint8")
     val args : ListBuffer[Any] = ListBuffer()
     val map : mutable.Map[String, Any] = mutable.Map()
diff --git a/tests/python/gpu/test_device.py b/tests/python/gpu/test_device.py
index 66772dc86c2..cd8145c3dea 100644
--- a/tests/python/gpu/test_device.py
+++ b/tests/python/gpu/test_device.py
@@ -19,35 +19,23 @@
 import numpy as np
 import unittest
 import os
+import logging
+
+from mxnet.test_utils import EnvManager
 
 shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
 keys = [1,2,3,4,5,6,7]
-num_gpus = len(mx.test_utils.list_gpus())
+num_gpus = mx.context.num_gpus()


 if num_gpus > 8 :
-    print("The machine has {} gpus. We will run the test on 8 gpus.".format(num_gpus))
-    print("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
+    logging.warning("The machine has %d gpus. We will run the test on 8 gpus.", num_gpus)
+    logging.warning("There is a limit for all PCI-E hardware on creating number of P2P peers. The limit is 8.")
     num_gpus = 8;

 gpus = range(1, 1+num_gpus)

-class EnvManager:
-    def __init__(self, key, val):
-        self._key = key
-        self._next_val = val
-        self._prev_val = None
-
-    def __enter__(self):
-        try:
-            self._prev_val = os.environ[self._key]
-        except KeyError:
-            self._prev_val = ''
-        os.environ[self._key] = self._next_val
-
-    def __exit__(self, ptype, value, trace):
-        os.environ[self._key] = self._prev_val
-
+@unittest.skipIf(mx.context.num_gpus() < 1, "test_device_pushpull needs at least 1 GPU")
 def test_device_pushpull():
     def check_dense_pushpull(kv_type):
         for shape, key in zip(shapes, keys):
@@ -63,20 +51,22 @@ def check_dense_pushpull(kv_type):
                 for x in range(n_gpus):
                     assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)
 
-    envs1 = '1'
-    key1 = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
-    envs2 = ['','1']
-    key2  = 'MXNET_KVSTORE_USETREE'
-    for i in range(2):
-        for val2 in envs2:
-            with EnvManager(key2, val2):
+    kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
+    kvstore_usetree_values = ['', '1']
+    kvstore_usetree = 'MXNET_KVSTORE_USETREE'
+
+    def run_all_usetree_values():
+        for x in kvstore_usetree_values:
+            with EnvManager(kvstore_usetree, x):
                 check_dense_pushpull('local')
                 check_dense_pushpull('device')
-
-        os.environ[key1] = envs1
-    os.environ[key1] = ''
-
-    print ("Passed")
+
+    # First pass uses the inherited tree array bound, second forces it to '1'.
+    # EnvManager restores the caller's value afterwards, even if a check fails.
+    run_all_usetree_values()
+    with EnvManager(kvstore_tree_array_bound, '1'):
+        run_all_usetree_values()
 
 if __name__ == '__main__':
     test_device_pushpull()
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 4232a590a5d..8ff8752f534 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -21,7 +21,7 @@
 import mxnet as mx
 import numpy as np
 import unittest
-from mxnet.test_utils import assert_almost_equal, default_context
+from mxnet.test_utils import assert_almost_equal, default_context, EnvManager  # shared helper; replaces the local class removed below
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown
@@ -30,22 +30,6 @@
 keys = [5, 7, 11]
 str_keys = ['b', 'c', 'd']
 
-class EnvManager:
-    def __init__(self, key, val):
-        self._key = key
-        self._next_val = val
-        self._prev_val = None
-
-    def __enter__(self):
-        try:
-            self._prev_val = os.environ[self._key]
-        except KeyError:
-            self._prev_val = ''
-        os.environ[self._key] = self._next_val
-
-    def __exit__(self, ptype, value, trace):
-        os.environ[self._key] = self._prev_val
-
 def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
     kv = mx.kv.create(kv_type)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services