You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/02 19:08:33 UTC

[GitHub] eric-haibin-lin commented on a change in pull request #8732: rsp push and rsp pull for comm device, used in kvstore('device')

eric-haibin-lin commented on a change in pull request #8732: rsp push and rsp pull for comm device, used in kvstore('device')
URL: https://github.com/apache/incubator-mxnet/pull/8732#discussion_r154503677
 
 

 ##########
 File path: tests/python/gpu/test_kvstore_gpu.py
 ##########
 @@ -26,44 +26,53 @@
 str_keys = ['b', 'c', 'd']
 
 
-def init_kv_with_str(stype='default'):
+def init_kv_with_str(stype='default', kv_type='local'):
     """init kv """
-    kv = mx.kv.create()
+    kv = mx.kv.create(kv_type)
     # single
     kv.init('a', mx.nd.zeros(shape, stype=stype))
     # list
     kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
     return kv
 
 
-def test_row_sparse_pull():
-    kv = init_kv_with_str('row_sparse')
-    kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+def test_rsp_push_pull():
+    def check_rsp_push_pull(kv_type, is_push_cpu=True):
+        kv = init_kv_with_str('row_sparse', kv_type)
+        kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+        push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
+        kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])
 
-    def check_row_sparse_pull(kv, count, ctx=default_context()):
-        num_rows = shape[0]
-        vals = []
-        row_ids = []
-        all_row_ids = np.arange(num_rows)
-        for i in range(count):
-            vals.append(mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse'))
-            row_id = np.random.randint(num_rows, size=num_rows)
-            row_ids.append(mx.nd.array(row_id, dtype='int64'))
-        row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
-        vals_to_pull = vals[0] if len(vals) == 1 else vals
+        def check_rsp_pull(kv, count, ctxs):
+            num_rows = shape[0]
+            vals = []
+            row_ids = []
+            all_row_ids = np.arange(num_rows)
+            for i in range(count):
+                vals.append(mx.nd.zeros(shape, ctx=ctxs[i]).tostype('row_sparse'))
+                row_id = np.random.randint(num_rows, size=num_rows)
+                row_ids.append(mx.nd.array(row_id, dtype='int64'))
+            row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
+            vals_to_pull = vals[0] if len(vals) == 1 else vals
 
-        kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
-        for val, row_id in zip(vals, row_ids):
-            retained = val.asnumpy()
-            excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
-            for row in range(num_rows):
-                expected_val = np.zeros_like(retained[row])
-                expected_val += 0 if row in excluded_row_ids else 1
-                assert_almost_equal(retained[row], expected_val)
+            kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+            for val, row_id in zip(vals, row_ids):
+                retained = val.asnumpy()
+                excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+                for row in range(num_rows):
+                    expected_val = np.zeros_like(retained[row])
+                    expected_val += 0 if row in excluded_row_ids else 2
+                    assert_almost_equal(retained[row], expected_val)
 
-    check_row_sparse_pull(kv, 1, mx.gpu(0))
-    check_row_sparse_pull(kv, 4, mx.gpu(0))
+        check_rsp_pull(kv, 1, [mx.gpu(0)])
+        check_rsp_pull(kv, 1, [mx.cpu(0)])
+        check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)])
+        check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)])
+
+    check_rsp_push_pull('local')
 
 Review comment:
  Do we have a test case where the same row_id is used for rsp_pull?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services