You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/11 20:03:57 UTC

[GitHub] eric-haibin-lin closed pull request #10845: [MXNET-406] support init/pull dense weight, push row_sparse grad in kvstore

eric-haibin-lin closed pull request #10845: [MXNET-406] support init/pull dense weight, push row_sparse grad in kvstore
URL: https://github.com/apache/incubator-mxnet/pull/10845
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 70de79b4610..a5d6a1dabef 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -112,41 +112,51 @@ class CommCPU : public Comm {
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int type = mshadow::kFloat32) override {
-    if (stype == kDefaultStorage) {
-      merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type);
-    } else {
-      merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type);
-    }
+    // Delayed allocation - the dense merged buffer might not be used at all if push()
+    // only sees sparse arrays
+    bool delay_alloc = true;
+    merge_buf_[key].merged = NDArray(shape, pinned_ctx_, delay_alloc, type);
   }
 
   const NDArray& Reduce(int key, const std::vector<NDArray>& src,
                         int priority) override {
     auto& buf = merge_buf_[key];
+    const auto stype = src[0].storage_type();
     // avoid extra copy for single device, but it may bring problems for
     // abnormal usage of kvstore
     if (src.size() == 1) {
-      if (src[0].storage_type() == kDefaultStorage) {
+      if (stype == kDefaultStorage) {
         return src[0];
-      } else {  // if sparse and only one GPU, always update weight on CPU
-        CopyFromTo(src[0], &buf.merged, priority);
-        return buf.merged;
+      } else {
+        // With 'local' kvstore, we could store the weight on CPU while computing
+        // the gradient on GPU when the weight is extremely large.
+        // To avoid copying the weight to the same context as the gradient,
+        // we always copy the gradient to the merged buf.
+        NDArray& merged = buf.merged_buf(stype);
+        CopyFromTo(src[0], &merged, priority);
+        return merged;
       }
     }
 
-    if (buf.merged.storage_type() == kDefaultStorage) {
+    NDArray& buf_merged = buf.merged_buf(stype);
+    // normal dense reduce
+    if (stype == kDefaultStorage) {
       std::vector<Engine::VarHandle> const_vars(src.size() - 1);
       std::vector<NDArray> reduce(src.size());
-      CopyFromTo(src[0], &buf.merged, priority);
-      reduce[0] = buf.merged;
+      CopyFromTo(src[0], &buf_merged, priority);
+      reduce[0] = buf_merged;
 
       if (buf.copy_buf.empty()) {
         buf.copy_buf.resize(src.size()-1);
         for (size_t j = 0; j < src.size() - 1; ++j) {
-          // allocate NDArray based on storage type
+          // allocate copy buffer
           buf.copy_buf[j] = NDArray(
             src[0].shape(), pinned_ctx_, false, src[0].dtype());
         }
       }
+      CHECK(stype == buf.copy_buf[0].storage_type())
+           << "Storage type mismatch detected. " << stype << "(src) vs. "
+           << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
       for (size_t i = 1; i < src.size(); ++i) {
         CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
         reduce[i] = buf.copy_buf[i-1];
@@ -161,7 +171,7 @@ class CommCPU : public Comm {
         FnProperty::kCPUPrioritized, priority, "KVStoreReduce");
 
     } else {
-      // buf.merged is a sparse ndarray.
+      // sparse reduce
       std::vector<Engine::VarHandle> const_vars(src.size());
       std::vector<NDArray> reduce(src.size());
 
@@ -172,26 +182,28 @@ class CommCPU : public Comm {
             src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype());
         }
       }
+      CHECK(stype == buf.copy_buf[0].storage_type())
+           << "Storage type mismatch detected. " << stype << "(src) vs. "
+           << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
       for (size_t i = 0; i < src.size(); ++i) {
         CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
         reduce[i] = buf.copy_buf[i];
         const_vars[i] = reduce[i].var();
       }
-      NDArray result = buf.merged;
-      Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+      Resource rsc = ResourceManager::Get()->Request(buf_merged.ctx(),
           ResourceRequest(ResourceRequest::kTempSpace));
       Engine::Get()->PushAsync(
-        [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          NDArray out = result;
+        [reduce, buf_merged, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+          NDArray out = buf_merged;
           is_serial_push_?
             ReduceSumCPUExSerial(reduce, &out)
             : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
           on_complete();
-        }, Context::CPU(), const_vars, {result.var(), rsc.var},
+        }, Context::CPU(), const_vars, {buf_merged.var(), rsc.var},
         FnProperty::kCPUPrioritized, priority, "KVStoreReduce");
     }
 
-    return buf.merged;
+    return buf_merged;
   }
 
   void Broadcast(int key, const NDArray& src,
@@ -200,10 +212,14 @@ class CommCPU : public Comm {
     if (mask == Context::kCPU) {
       for (auto d : dst) CopyFromTo(src, d, priority);
     } else {
-      // first copy data to cpu, then broadcast
-      auto& buf = merge_buf_[key];
-      CopyFromTo(src, &buf.merged, priority);
-      for (auto d : dst) CopyFromTo(buf.merged, d, priority);
+      // First copy data to pinned_ctx, then broadcast.
+      // Note that kv.init initializes the data on pinned_ctx.
+      // This branch indicates push() with ndarrays on gpus was called,
+      // and the source is copied to gpu ctx.
+      // Also indicates that buffers are already initialized during push().
+      auto& buf = merge_buf_[key].merged_buf(src.storage_type());
+      CopyFromTo(src, &buf, priority);
+      for (auto d : dst) CopyFromTo(buf, d, priority);
     }
   }
 
@@ -228,7 +244,14 @@ class CommCPU : public Comm {
       NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out :
           NDArray(kRowSparseStorage, src.shape(), src.ctx(), true,
                   src.dtype(), src.aux_types());
-
+      if (!is_diff_var) {
+        common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) +
+                        "refers to the same NDArray as the one stored in KVStore."
+                        "Performing row_sparse_pull() with such output is going to change the "
+                        "data stored in KVStore. Incorrect result may be generated "
+                        "next time row_sparse_pull() is called. To avoid such an issue,"
+                        "consider create a new NDArray buffer to store the output.");
+      }
       Engine::Get()->PushAsync(
         [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob& indices = row_id.data();
@@ -392,6 +415,24 @@ class CommCPU : public Comm {
     NDArray merged;
     /// \brief the cpu buffer for gpu data
     std::vector<NDArray> copy_buf;
+    /// \brief the merged buffer for the given storage type
+    inline NDArray& merged_buf(NDArrayStorageType stype) {
+      if (stype == kDefaultStorage) {
+        return merged;
+      }
+      CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype;
+      // check if sparse_merged is initialized
+      if (sparse_merged.is_none()) {
+        CHECK(!merged.is_none());
+        sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(),
+                                true, merged.dtype());
+      }
+      return sparse_merged;
+    }
+
+   private:
+    /// \brief the sparse merged value
+    NDArray sparse_merged;
   };
   std::unordered_map<int, BufferEntry> merge_buf_;
   size_t bigarray_bound_;
@@ -417,7 +458,7 @@ class CommDevice : public Comm {
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
-    sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
+    sorted_key_attrs_.emplace_back(key, shape, dtype);
   }
 
   void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -451,10 +492,12 @@ class CommDevice : public Comm {
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
 
-    const NDArrayStorageType stype = buf.merged.storage_type();
+    const NDArrayStorageType stype = src[0].storage_type();
+    NDArray& buf_merged = buf.merged_buf(stype);
+    // normal dense reduce
     if (stype == kDefaultStorage) {
-      CopyFromTo(src[0], &(buf.merged), priority);
-      reduce[0] = buf.merged;
+      CopyFromTo(src[0], &buf_merged, priority);
+      reduce[0] = buf_merged;
 
       if (buf.copy_buf.empty()) {
         // TODO(mli) this results in large device memory usage for huge ndarray,
@@ -464,7 +507,7 @@ class CommDevice : public Comm {
         buf.copy_buf.resize(src.size()-1);
         for (size_t i = 0; i < src.size()-1; ++i) {
           buf.copy_buf[i] = NDArray(
-            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+            buf_merged.shape(), buf_merged.ctx(), false, buf_merged.dtype());
         }
       }
       for (size_t i = 0; i < src.size()-1; ++i) {
@@ -472,21 +515,24 @@ class CommDevice : public Comm {
         reduce[i+1] = buf.copy_buf[i];
       }
     } else {
+      // sparse reduce
       if (buf.copy_buf.empty()) {
+        // initialize buffer for copying during reduce
         buf.copy_buf.resize(src.size());
         for (size_t j = 0; j < src.size(); ++j) {
-          buf.copy_buf[j] = NDArray(
-            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
-            true, buf.merged.dtype());
+          buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
         }
       }
+      CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
+           << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
+           << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
       for (size_t i = 0; i < src.size(); ++i) {
         CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
         reduce[i] = buf.copy_buf[i];
       }
     }
-    ElementwiseSum(reduce, &buf.merged, priority);
-    return buf.merged;
+    ElementwiseSum(reduce, &buf_merged, priority);
+    return buf_merged;
   }
 
   const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
@@ -547,10 +593,10 @@ class CommDevice : public Comm {
         }
       }
     } else {
-      auto& buf = merge_buf_[key];
-      CopyFromTo(src, &buf.merged, priority);
+      auto& buf_merged = merge_buf_[key].merged_buf(src.storage_type());
+      CopyFromTo(src, &buf_merged, priority);
       for (auto d : dst) {
-        CopyFromTo(buf.merged, d, priority);
+        CopyFromTo(buf_merged, d, priority);
       }
     }
   }
@@ -575,6 +621,14 @@ class CommDevice : public Comm {
       NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out :
           NDArray(kRowSparseStorage, out->shape(), src.ctx(), true,
                   out->dtype(), out->aux_types());
+      if (!is_diff_var) {
+        common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) +
+                        "refers to the same NDArray as the one stored in KVStore."
+                        "Performing row_sparse_pull() with such output is going to change the "
+                        "data stored in KVStore. Incorrect result may be generated "
+                        "next time row_sparse_pull() is called. To avoid such an issue,"
+                        "consider create a new NDArray buffer to store the output.");
+      }
 
       Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob& indices = row_id.data();
@@ -647,7 +701,7 @@ class CommDevice : public Comm {
 #endif
   }
 
-  using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
+  using KeyAttrs = std::tuple<int, TShape, int>;
   // try to allocate buff on device evenly
   void InitMergeBuffer(const std::vector<Context>& devs) {
     std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -659,11 +713,11 @@ class CommDevice : public Comm {
     for (auto d : devs) {
       ctx_info[d.dev_id] = std::make_pair(d, 0);
     }
+
     for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
       const int key  = std::get<0>(sorted_key_attrs_[i]);
       const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
       const int type = std::get<2>(sorted_key_attrs_[i]);
-      const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
       auto& buf = merge_buf_[key];
       Context ctx;
       size_t min_size = std::numeric_limits<size_t>::max();
@@ -674,11 +728,10 @@ class CommDevice : public Comm {
           min_size = size;
         }
       }
-      if (stype == kDefaultStorage) {
-        buf.merged = NDArray(shape, ctx, false, type);
-      } else {
-        buf.merged = NDArray(stype, shape, ctx, true, type);
-      }
+      // Delayed allocation - as the dense merged buffer might not be used at all if push()
+      // only sees sparse arrays
+      bool delay_alloc = true;
+      buf.merged = NDArray(shape, ctx, delay_alloc, type);
       ctx_info[ctx.dev_id].second += shape.Size();
     }
     inited_ = true;
@@ -687,9 +740,9 @@ class CommDevice : public Comm {
   std::vector<KeyAttrs> sorted_key_attrs_;
   /// \brief temporal space for pushing and pulling
   struct BufferEntry {
-    /// \brief the merged value
+    /// \brief the dense merged value for reduce and broadcast operations
     NDArray merged;
-    /// \brief the gpu buffer
+    /// \brief the gpu buffer for copy during reduce operation
     std::vector<NDArray> copy_buf;
     /// \brief the residual buffer for gradient compression
     std::vector<NDArray> residual;
@@ -697,6 +750,26 @@ class CommDevice : public Comm {
     std::vector<NDArray> compressed_send_buf;
     /// \brief the small buffer for compressed data in receiver
     std::vector<NDArray> compressed_recv_buf;
+
+    /// \brief the merged buffer for the given storage type (could be either dense or row_sparse)
+    inline NDArray& merged_buf(NDArrayStorageType stype) {
+      if (stype == kDefaultStorage) {
+        CHECK(!merged.is_none()) << "unintialized merge buffer detected";
+        return merged;
+      }
+      CHECK(stype == kRowSparseStorage) << "unexpected storage type " << stype;
+      // check if sparse_merged is initialized
+      if (sparse_merged.is_none()) {
+        CHECK(!merged.is_none());
+        sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(),
+                                true, merged.dtype());
+      }
+      return sparse_merged;
+    }
+
+   private:
+    /// \brief the sparse merged value for reduce and rowsparse broadcast operations
+    NDArray sparse_merged;
   };
   std::unordered_map<int, BufferEntry> merge_buf_;
   bool inited_;
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 3383c97f926..2ac6c11a167 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -276,8 +276,8 @@ class KVStoreLocal : public KVStore {
       // invalid, print warning messages once
       if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
         LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
-                  << "This call has been ignored. "
-                  << "Please make sure to use row_sparse_pull with row_ids.";
+                     "This call has been ignored. Please make sure to use"
+                     "kv.row_sparse_pull() or module.prepare() with row_ids.";
         this->warnings_printed_.insert(key);
       }
       return false;
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 82de0949ccc..ada0183924c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1297,7 +1297,7 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priori
       CHECK_EQ(source[i].ctx().dev_mask(), Context::kCPU)
           << "operands context mismatch";
     } else {
-      CHECK(source[i].ctx() == out->ctx())
+      CHECK_EQ(source[i].ctx(), out->ctx())
           << "operands context mismatch";
     }
   }
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index a14feac7a3a..c087da1193f 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -76,7 +76,7 @@ def as_float32(s):
     return np.array(compr), np.array(new_residual).reshape(arr.shape), np.array(decompr).reshape(arr.shape)
 
 ## individual key interface
-def test_kvstore(kv_type):
+def test_kvstore(kv_type, stype):
     print(kv_type)
     kv = mx.kv.create(kv_type)
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
@@ -87,7 +87,7 @@ def test_kvstore(kv_type):
     for i in range(nrepeat):
         for j in range(len(keys)):
             kv.push(keys[j], [mx.nd.array(
-                data[i][j][g], mx.gpu(g)) for g in range(nworker)])
+                data[i][j][g], mx.gpu(g)).tostype(stype) for g in range(nworker)])
 
         res = [a + b * lr for a, b in zip(res, [sum(d) for d in data[i]])]
         for j in range(len(keys)):
@@ -211,7 +211,7 @@ def check_compr_random(kv, threshold):
     check_compr_random(kv, threshold)
 
 ## group keys interface
-def test_group_kvstore(kv_type):
+def test_group_kvstore(kv_type, stype):
     print(kv_type)
     kv = mx.kv.create(kv_type)
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
@@ -220,7 +220,7 @@ def test_group_kvstore(kv_type):
     out = [[mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)] for s in shapes]
     for i in range(nrepeat):
         kv.push(keys, [[
-            mx.nd.array(data[i][j][g], mx.gpu(g)) for g in range(nworker)]
+            mx.nd.array(data[i][j][g], mx.gpu(g)).tostype(stype) for g in range(nworker)]
                        for j in range(len(keys))])
 
         kv.pull(keys, out=out)
@@ -234,6 +234,7 @@ def test_group_kvstore(kv_type):
     keys = [3, 5, 7]
     # let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
     shapes = [(4, 4), (100, 100), (2000, 2000)]
+    stypes = ['default', 'row_sparse']
 
     gc_init_test_key = 9
 
@@ -241,16 +242,17 @@ def test_group_kvstore(kv_type):
     nworker = 4
     nrepeat = 10
 
-    ## generate data
+    # generate data
     data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
 
-    test_kvstore('local_update_cpu')
-    test_kvstore('local_allreduce_cpu')
-    test_kvstore('local_allreduce_device')
+    for stype in stypes:
+        test_kvstore('local_update_cpu', stype)
+        test_kvstore('local_allreduce_cpu', stype)
+        test_kvstore('local_allreduce_device', stype)
 
-    # compression for local kvstore happens only when reduce is on device
+    ## compression for local kvstore happens only when reduce is on device
     test_compress_kvstore('local_allreduce_device')
-
-    test_group_kvstore('local_update_cpu')
-    test_group_kvstore('local_allreduce_cpu')
-    test_group_kvstore('local_allreduce_device')
+    for stype in stypes:
+        test_group_kvstore('local_update_cpu', stype)
+        test_group_kvstore('local_allreduce_cpu', stype)
+        test_group_kvstore('local_allreduce_device', stype)
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index c56046ca900..44d522ab9a5 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -19,8 +19,8 @@
 import mxnet as mx
 import numpy as np
 import unittest
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, assert_exception
-from common import setup_module, with_seed
+from mxnet.test_utils import rand_ndarray, assert_almost_equal
+from common import setup_module, with_seed, assertRaises
 from mxnet.base import py_str, MXNetError
 
 shape = (4, 4)
@@ -54,14 +54,16 @@ def check_diff_to_scalar(A, x):
 @with_seed()
 def test_single_kv_pair():
     """single key-value pair push & pull"""
-    def check_single_kv_pair(kv, key):
-        kv.push(key, mx.nd.ones(shape))
+    def check_single_kv_pair(kv, key, stype):
+        kv.push(key, mx.nd.ones(shape).tostype(stype))
         val = mx.nd.empty(shape)
         kv.pull(key, out=val)
         check_diff_to_scalar(val, 1)
 
-    check_single_kv_pair(init_kv(), 3)
-    check_single_kv_pair(init_kv_with_str(), 'a')
+    stypes = ['default', 'row_sparse']
+    for stype in stypes:
+        check_single_kv_pair(init_kv(), 3, stype)
+        check_single_kv_pair(init_kv_with_str(), 'a', stype)
 
 @with_seed()
 def test_row_sparse_pull():
@@ -107,46 +109,52 @@ def check_init(kv, key):
 @with_seed()
 def test_list_kv_pair():
     """list key-value pair push & pull"""
-    def check_list_kv_pair(kv, key):
-        kv.push(key, [mx.nd.ones(shape)*4] * len(key))
+    def check_list_kv_pair(kv, key, stype):
+        kv.push(key, [mx.nd.ones(shape).tostype(stype)*4] * len(key))
         val = [mx.nd.empty(shape)] * len(key)
         kv.pull(key, out=val)
         for v in val:
             check_diff_to_scalar(v, 4)
 
-    check_list_kv_pair(init_kv(), keys)
-    check_list_kv_pair(init_kv_with_str(), str_keys)
+    stypes = ['default', 'row_sparse']
+    for stype in stypes:
+        check_list_kv_pair(init_kv(), keys, stype)
+        check_list_kv_pair(init_kv_with_str(), str_keys, stype)
 
 
 @with_seed()
 def test_aggregator():
     """aggregate value on muliple devices"""
 
-    def check_aggregator(kv, key, key_list):
+    def check_aggregator(kv, key, key_list, stype):
         # devices
         num_devs = 4
         devs = [mx.Context('cpu', i) for i in range(num_devs)]
 
         # single
-        vals = [mx.nd.ones(shape, d) for d in devs]
+        vals = [mx.nd.ones(shape, d).tostype(stype) for d in devs]
+        outs = [mx.nd.empty(shape, d) for d in devs]
 
         kv.push(key, vals)
-        kv.pull(key, out=vals)
+        kv.pull(key, out=outs)
 
-        for v in vals:
-            check_diff_to_scalar(v, num_devs)
+        for out in outs:
+            check_diff_to_scalar(out, num_devs)
 
         # list
-        vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list)
+        vals = [[mx.nd.ones(shape, d).tostype(stype)*2.0 for d in devs]] * len(key_list)
+        outs = [[mx.nd.empty(shape, d) for d in devs]] * len(key_list)
         kv.push(key_list, vals)
-        kv.pull(key_list, out=vals)
+        kv.pull(key_list, out=outs)
 
-        for vv in vals:
-            for v in vv:
-                check_diff_to_scalar(v, num_devs * 2.0)
+        for out in outs:
+            for o in out:
+                check_diff_to_scalar(o, num_devs * 2.0)
 
-    check_aggregator(init_kv(), 3, keys)
-    check_aggregator(init_kv_with_str(), 'a', str_keys)
+    stypes = ['default', 'row_sparse']
+    for stype in stypes:
+        check_aggregator(init_kv(), 3, keys, stype)
+        check_aggregator(init_kv_with_str(), 'a', str_keys, stype)
 
 
 @with_seed()
@@ -202,43 +210,47 @@ def str_updater(key, recv, local):
     local += recv
 
 @with_seed()
-def test_updater(dev = 'cpu'):
+def test_updater(dev='cpu'):
     """updater"""
 
-    def check_updater(kv, key, key_list):
+    def check_updater(kv, key, key_list, stype):
         # devices
         num_devs = 4
         devs = [mx.Context(dev, i) for i in range(num_devs)]
 
         # single
-        vals = [mx.nd.ones(shape, d) for d in devs]
+        vals = [mx.nd.ones(shape, d).tostype(stype) for d in devs]
+        outs = [mx.nd.empty(shape, d) for d in devs]
 
         kv.push(key, vals)
-        kv.pull(key, out=vals)
+        kv.pull(key, out=outs)
 
-        for v in vals:
-            check_diff_to_scalar(v, num_devs)
+        for out in outs:
+            check_diff_to_scalar(out, num_devs)
 
         # list
-        vals = [[mx.nd.ones(shape, d) for d in devs]] * len(key_list)
+        vals = [[mx.nd.ones(shape, d).tostype(stype) for d in devs]] * len(key_list)
+        outs = [[mx.nd.empty(shape, d) for d in devs]] * len(key_list)
 
         num_push = 4
         for i in range(num_push):
             kv.push(key_list, vals)
 
-        kv.pull(key_list, out=vals)
+        kv.pull(key_list, out=outs)
 
-        for vv in vals:
-            for v in vv:
-                check_diff_to_scalar(v, num_devs * num_push)
+        for out in outs:
+            for o in out:
+                check_diff_to_scalar(o, num_devs * num_push)
 
-    kv = init_kv()
-    kv._set_updater(updater)
-    check_updater(kv, 3, keys)
+    stypes = ['default', 'row_sparse']
+    for stype in stypes:
+        kv = init_kv()
+        kv._set_updater(updater)
+        check_updater(kv, 3, keys, stype)
 
-    str_kv = init_kv_with_str()
-    str_kv._set_updater(str_updater)
-    check_updater(str_kv, 'a', str_keys)
+        str_kv = init_kv_with_str()
+        str_kv._set_updater(str_updater)
+        check_updater(str_kv, 'a', str_keys, stype)
 
 @with_seed()
 def test_get_type():
@@ -263,30 +275,30 @@ def check_ignored_pull_list(kv, key):
 
     def check_invalid_rsp_pull_single(kv, key):
         dns_val = mx.nd.ones(shape) * 2
-        assert_exception(kv.row_sparse_pull, MXNetError,
-                         key, out=dns_val, row_ids=mx.nd.array([1]))
+        assertRaises(MXNetError, kv.row_sparse_pull,
+                     key, out=dns_val, row_ids=mx.nd.array([1]))
 
     def check_invalid_rsp_pull_list(kv, key):
         dns_val = [mx.nd.ones(shape) * 2] * len(key)
-        assert_exception(kv.row_sparse_pull, MXNetError, key, out=dns_val,
-                         row_ids=[mx.nd.array([1])] * len(key))
+        assertRaises(MXNetError, kv.row_sparse_pull, key, out=dns_val,
+                     row_ids=[mx.nd.array([1])] * len(key))
 
     def check_invalid_key_types_single(kv, key):
         dns_val = mx.nd.ones(shape) * 2
         rsp_val = dns_val.tostype('row_sparse')
-        assert_exception(kv.init, MXNetError, key, dns_val)
-        assert_exception(kv.push, MXNetError, key, dns_val)
-        assert_exception(kv.pull, MXNetError, key, dns_val)
-        assert_exception(kv.row_sparse_pull, MXNetError, key, rsp_val,
-                         row_ids=mx.nd.array([1]))
+        assertRaises(MXNetError, kv.init, key, dns_val)
+        assertRaises(MXNetError, kv.push, key, dns_val)
+        assertRaises(MXNetError, kv.pull, key, dns_val)
+        assertRaises(MXNetError, kv.row_sparse_pull, key, rsp_val,
+                     row_ids=mx.nd.array([1]))
 
     def check_invalid_key_types_list(kv, key):
         dns_val = [mx.nd.ones(shape) * 2] * len(key)
         rsp_val = [val.tostype('row_sparse') for val in dns_val]
-        assert_exception(kv.init, MXNetError, key, dns_val)
-        assert_exception(kv.push, MXNetError, key, dns_val)
-        assert_exception(kv.pull, MXNetError, key, dns_val)
-        assert_exception(kv.row_sparse_pull, MXNetError, key, rsp_val,
+        assertRaises(MXNetError, kv.init, key, dns_val)
+        assertRaises(MXNetError, kv.push, key, dns_val)
+        assertRaises(MXNetError, kv.pull, key, dns_val)
+        assertRaises(MXNetError, kv.row_sparse_pull, key, rsp_val,
                          row_ids=[mx.nd.array([1])] * len(key))
 
     int_kv = init_kv()
@@ -309,5 +321,3 @@ def check_invalid_key_types_list(kv, key):
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services