You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/30 18:02:03 UTC

[incubator-mxnet] branch master updated: Add string interface to updater to make it consistent with kvstore (#7585)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 470b564  Add string interface to updater to make it consistent with kvstore (#7585)
470b564 is described below

commit 470b56437290b33fef51b067c96fdce08cefa584
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Aug 30 11:02:00 2017 -0700

    Add string interface to updater to make it consistent with kvstore (#7585)
    
    * str kv updater draft
    
    * backward compatibility for other languages
    
    * add capi MXKVStoreSetUpdaterEx
    
    * fix nightly testkvstore test
    
    * convert c_char_p/byte to str for python3
    
    * add key type restriction to backend
    
    * add test to check mixed key types
    
    * remove nested catch throw
---
 include/mxnet/c_api.h                 |  48 +++++++++++-
 include/mxnet/kvstore.h               |  24 +++++-
 python/mxnet/kvstore.py               | 100 ++++++++++++++++--------
 python/mxnet/optimizer.py             |   4 +
 src/c_api/c_api.cc                    |  57 ++++++++++++--
 src/kvstore/kvstore_local.h           | 140 +++++++++++++++++++++++++---------
 tests/nightly/test_kvstore.py         |   4 +-
 tests/python/unittest/test_kvstore.py |  96 +++++++++++++++--------
 8 files changed, 359 insertions(+), 114 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bba6190..ef9d31e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1602,7 +1602,7 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
                               int priority);
 
 /*!
- * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
  *        The NDArray pulled back will be in row_sparse storage with only the specified
  *        row_ids present based row_ids (others rows are zeros).
  * \param handle handle to the kvstore
@@ -1615,10 +1615,28 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
  */
 MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
                                      mx_uint num,
-                                     const char** keys,
+                                     const int* keys,
                                      NDArrayHandle* vals,
                                      const NDArrayHandle* row_ids,
                                      int priority);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
+ *        The NDArray pulled back will be in row_sparse storage with only the rows
+ *        specified by row_ids present (other rows are zeros).
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param row_ids the list of row_id NDArrays
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
+                                       mx_uint num,
+                                       const char** keys,
+                                       NDArrayHandle* vals,
+                                       const NDArrayHandle* row_ids,
+                                       int priority);
 
 /*!
  * \brief user-defined updater for the kvstore
@@ -1633,7 +1651,19 @@ typedef void (MXKVStoreUpdater)(int key,
                                 NDArrayHandle local,
                                 void *handle);
 /*!
- * \brief register an push updater
+ * \brief user-defined updater for the kvstore with string keys
+ * It's this updater's responsibility to delete \a recv and \a local
+ * \param key the key
+ * \param recv the pushed value on this key
+ * \param local the value stored on local on this key
+ * \param handle The additional handle to the updater
+ */
+typedef void (MXKVStoreStrUpdater)(const char* key,
+                                   NDArrayHandle recv,
+                                   NDArrayHandle local,
+                                   void *handle);
+/*!
+ * \brief register a push updater
  * \param handle handle to the KVStore
  * \param updater udpater function
  * \param updater_handle The additional handle used to invoke the updater
@@ -1643,6 +1673,18 @@ MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle,
                                   MXKVStoreUpdater updater,
                                   void *updater_handle);
 /*!
+ * \brief register a push updater with int keys and one with string keys
+ * \param handle handle to the KVStore
+ * \param updater updater function with int keys
+ * \param str_updater updater function with string keys
+ * \param updater_handle The additional handle used to invoke the updater
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
+                                    MXKVStoreUpdater updater,
+                                    MXKVStoreStrUpdater str_updater,
+                                    void *updater_handle);
+/*!
  * \brief get the type of the kvstore
  * \param handle handle to the KVStore
  * \param type a string type
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 9ea63b4..bca88a5 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -202,6 +202,10 @@ class KVStore {
    * \brief the prototype of user-defined updater
    */
   typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
+  /**
+   * \brief the prototype of user-defined updater with string keys
+   */
+  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
   /*!
    * \brief set an updater
    *
@@ -215,6 +219,19 @@ class KVStore {
     CHECK(updater) << "invalid updater";
     updater_ = updater;
   }
+  /*!
+   * \brief set an updater with string keys
+   *
+   * Given a string key, assume \a x is the received (pushed) value and \a y is the
+   * value stored on the store node. The store updates \a y by `h(x, &y)`. The
+   * default \a h is ASSIGN, namely `*y = x`.
+   *
+   * \param updater user-defined string updater, default is assign
+   */
+  virtual void set_updater(const StrUpdater& updater) {
+    CHECK(updater) << "invalid updater";
+    str_updater_ = updater;
+  }
 
   /******************************************************
    * the following are used for multi-machines.
@@ -356,11 +373,16 @@ class KVStore {
 
  protected:
   /**
-   * \brief the user-defined  updater
+   * \brief the user-defined updater
    */
   Updater updater_;
 
   /**
+   * \brief the user-defined updater with string keys
+   */
+  StrUpdater str_updater_;
+
+  /**
    * \brief the kvstore type
    */
   std::string type_;
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 2af70e3..bc034c5 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -29,26 +29,39 @@ from .base import NDArrayHandle, KVStoreHandle
 from . import optimizer as opt
 
 def _ctype_key_value(keys, vals):
+    """
+    Returns ctype arrays for the key-value args, and whether string keys are used.
+    For internal use only.
+    """
     if isinstance(keys, (tuple, list)):
         assert(len(keys) == len(vals))
         c_keys = []
         c_vals = []
+        use_str_keys = None
         for key, val in zip(keys, vals):
-            c_key_i, c_val_i = _ctype_key_value(key, val)
+            c_key_i, c_val_i, str_keys_i = _ctype_key_value(key, val)
             c_keys += c_key_i
             c_vals += c_val_i
-        return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))
-    names = []
-    keys = str(keys)
+            use_str_keys = str_keys_i if use_str_keys is None else use_str_keys
+            assert(use_str_keys == str_keys_i), "inconsistent types of keys detected."
+        c_keys_arr = c_array(ctypes.c_char_p, c_keys) if use_str_keys \
+                     else c_array(ctypes.c_int, c_keys)
+        c_vals_arr = c_array(NDArrayHandle, c_vals)
+        return (c_keys_arr, c_vals_arr, use_str_keys)
+
+    assert(isinstance(keys, (int,) + string_types)), \
+           "unexpected type for keys: " + str(type(keys))
+    use_str_keys = isinstance(keys, string_types)
     if isinstance(vals, NDArray):
-        names.append(c_str(keys))
-        return (c_array(ctypes.c_char_p, names),
-                c_array(NDArrayHandle, [vals.handle]))
+        c_keys = c_array(ctypes.c_char_p, [c_str(keys)]) if use_str_keys \
+                 else c_array(ctypes.c_int, [keys])
+        return (c_keys, c_array(NDArrayHandle, [vals.handle]), use_str_keys)
     else:
         for value in vals:
             assert(isinstance(value, NDArray))
-        return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
-                c_array(NDArrayHandle, [value.handle for value in vals]))
+        c_keys = c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)) if use_str_keys \
+                 else c_array(ctypes.c_int, [keys] * len(vals))
+        return (c_keys, c_array(NDArrayHandle, [value.handle for value in vals]), use_str_keys)
 
 def _updater_wrapper(updater):
     """A wrapper for the user-defined handle."""
@@ -74,6 +87,7 @@ class KVStore(object):
         self.handle = handle
         self._updater = None
         self._updater_func = None
+        self._str_updater_func = None
 
     def __del__(self):
         check_call(_LIB.MXKVStoreFree(self.handle))
@@ -88,7 +102,7 @@ class KVStore(object):
 
         Parameters
         ----------
-        key : str or sequence of str
+        key : str, int, or sequence of str or int
             The keys.
         value : NDArray or sequence of NDArray
             Values corresponding to the keys.
@@ -106,11 +120,14 @@ class KVStore(object):
         [ 2.  2.  2.]]
 
         >>> # init a list of key-value pairs
-        >>> keys = ['5', '7', '9']
+        >>> keys = [5, 7, 9]
         >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
         """
-        ckeys, cvals = _ctype_key_value(key, value)
-        check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
+        if use_str_keys:
+            check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
+        else:
+            check_call(_LIB.MXKVStoreInit(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
 
     def push(self, key, value, priority=0):
         """ Pushes a single or a sequence of key-value pairs into the store.
@@ -123,7 +140,7 @@ class KVStore(object):
 
         Parameters
         ----------
-        key : str or list of str
+        key : str, int, or sequence of str or int
             Keys.
 
         value : NDArray or list of NDArray or list of list of NDArray
@@ -154,6 +171,7 @@ class KVStore(object):
 
         >>> # push a list of keys.
         >>> # single device
+        >>> keys = [4, 5, 6]
         >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
@@ -162,6 +180,7 @@ class KVStore(object):
         [ 1.  1.  1.]]
 
         >>> # multiple devices:
+        >>> keys = ['7', '8', '9']
         >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
         >>> kv.push(keys, b)
         >>> kv.pull(keys, out=b)
@@ -169,10 +188,13 @@ class KVStore(object):
         [[ 4.  4.  4.]
         [ 4.  4.  4.]]
         """
-        ckeys, cvals = _ctype_key_value(key, value)
-        check_call(_LIB.MXKVStorePushEx(
-            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
-            ctypes.c_int(priority)))
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
+        if use_str_keys:
+            check_call(_LIB.MXKVStorePushEx(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+        else:
+            check_call(_LIB.MXKVStorePush(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
 
 
     def pull(self, key, out=None, priority=0):
@@ -191,7 +213,7 @@ class KVStore(object):
 
         Parameters
         ----------
-        key : int or list of int
+        key : str, int, or sequence of str or int
             Keys.
 
         out: NDArray or list of NDArray or list of list of NDArray
@@ -220,13 +242,14 @@ class KVStore(object):
 
         >>> # pull a list of key-value pairs.
         >>> # On single device
-        >>> keys = ['5', '7', '9']
+        >>> keys = [5, 7, 9]
         >>> b = [mx.nd.zeros(shape)]*len(keys)
         >>> kv.pull(keys, out=b)
         >>> print b[1].asnumpy()
         [[ 2.  2.  2.]
         [ 2.  2.  2.]]
         >>> # On multiple devices
+        >>> keys = ['6', '8', '10']
         >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
         >>> kv.pull(keys, out=b)
         >>> print b[1][1].asnumpy()
@@ -234,10 +257,13 @@ class KVStore(object):
         [ 2.  2.  2.]]
         """
         assert(out is not None)
-        ckeys, cvals = _ctype_key_value(key, out)
-        check_call(_LIB.MXKVStorePullEx(
-            self.handle, mx_uint(len(ckeys)), ckeys, cvals,
-            ctypes.c_int(priority)))
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+        if use_str_keys:
+            check_call(_LIB.MXKVStorePullEx(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+        else:
+            check_call(_LIB.MXKVStorePull(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single row_sparse value or a sequence of row_sparse values from the store
@@ -250,7 +276,7 @@ class KVStore(object):
 
         Parameters
         ----------
-        key : str or list of str
+        key : str, int, or sequence of str or int
             Keys.
 
         out: NDArray or list of NDArray or list of list of NDArray
@@ -291,12 +317,16 @@ class KVStore(object):
         """
         assert(out is not None)
         assert(row_ids is not None)
-        ckeys, cvals = _ctype_key_value(key, out)
-        _, crow_ids = _ctype_key_value(key, row_ids)
-        assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
-
-        check_call(_LIB.MXKVStorePullRowSparse(
-            self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+        ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+        _, crow_ids, _ = _ctype_key_value(key, row_ids)
+        assert(len(crow_ids) == len(cvals)), \
+               "the number of row_ids doesn't match the number of values"
+        if use_str_keys:
+            check_call(_LIB.MXKVStorePullRowSparseEx(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+        else:
+            check_call(_LIB.MXKVStorePullRowSparse(
+                self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
 
 
     def set_optimizer(self, optimizer):
@@ -436,10 +466,16 @@ class KVStore(object):
         [ 6.  6.  6.]]
         """
         self._updater = updater
+        # set updater with int keys
         _updater_proto = ctypes.CFUNCTYPE(
             None, ctypes.c_int, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
         self._updater_func = _updater_proto(_updater_wrapper(updater))
-        check_call(_LIB.MXKVStoreSetUpdater(self.handle, self._updater_func, None))
+        # set updater with str keys
+        _str_updater_proto = ctypes.CFUNCTYPE(
+            None, ctypes.c_char_p, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
+        self._str_updater_func = _str_updater_proto(_updater_wrapper(updater))
+        check_call(_LIB.MXKVStoreSetUpdaterEx(self.handle, self._updater_func,
+                                              self._str_updater_func, None))
 
 
     def _barrier(self):
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index e7e283f..099d2b7 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -21,6 +21,7 @@ import pickle
 import logging
 import warnings
 import numpy
+from .base import py_str
 from .ndarray import (NDArray, zeros, clip, sqrt, sign, array, maximum, abs as NDabs)
 from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                       mp_sgd_update, mp_sgd_mom_update)
@@ -949,6 +950,9 @@ class Updater(object):
 
     def __call__(self, index, grad, weight):
         """Updates weight given gradient and index."""
+        # convert ctypes.c_char_p.value (bytes) back to python str if needed
+        if isinstance(index, bytes):
+            index = py_str(index)
         if index not in self.states:
             self.states[index] = self.optimizer.create_state(index, weight)
             self.states_synced[index] = True
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 088e208..5171e27 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -829,12 +829,12 @@ int MXKVStorePullEx(KVStoreHandle handle,
 
 int MXKVStorePullRowSparse(KVStoreHandle handle,
                            mx_uint num,
-                           const char** keys,
+                           const int* keys,
                            NDArrayHandle* vals,
                            const NDArrayHandle* row_ids,
                            int priority) {
   API_BEGIN();
-  std::vector<std::string> v_keys(num);
+  std::vector<int> v_keys(num);
   std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
   for (mx_uint i = 0; i < num; ++i) {
     v_keys[i] = keys[i];
@@ -845,10 +845,27 @@ int MXKVStorePullRowSparse(KVStoreHandle handle,
   API_END();
 }
 
-int MXKVStoreSetUpdater(KVStoreHandle handle,
-                        MXKVStoreUpdater updater,
-                        void* updater_handle) {
+int MXKVStorePullRowSparseEx(KVStoreHandle handle,
+                             mx_uint num,
+                             const char** keys,
+                             NDArrayHandle* vals,
+                             const NDArrayHandle* row_ids,
+                             int priority) {
   API_BEGIN();
+  std::vector<std::string> v_keys(num);
+  std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
+  for (mx_uint i = 0; i < num; ++i) {
+    v_keys[i] = keys[i];
+    v_val_rowids[i] = std::make_pair(static_cast<NDArray*>(vals[i]),
+                                     *static_cast<NDArray*>(row_ids[i]));
+  }
+  static_cast<KVStore*>(handle)->PullRowSparse(v_keys, v_val_rowids, priority);
+  API_END();
+}
+
+void MXKVStoreSetUpdaterImpl(KVStoreHandle handle,
+                             MXKVStoreUpdater updater,
+                             void* updater_handle) {
   MXKVStoreUpdater * updater_temp = updater;
   void* updater_handle_temp = updater_handle;
   std::function<void(int, const NDArray&, NDArray*)> updt
@@ -860,6 +877,36 @@ int MXKVStoreSetUpdater(KVStoreHandle handle,
     updater_temp(key, recv_copy, local_copy, updater_handle_temp);
   };
   static_cast<KVStore*>(handle)->set_updater(updt);
+}
+
+int MXKVStoreSetUpdater(KVStoreHandle handle,
+                        MXKVStoreUpdater updater,
+                        void* updater_handle) {
+  API_BEGIN();
+  MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
+  API_END();
+}
+
+int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
+                          MXKVStoreUpdater updater,
+                          MXKVStoreStrUpdater str_updater,
+                          void* updater_handle) {
+  API_BEGIN();
+  // set updater with int keys
+  MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
+  // set updater with string keys
+  MXKVStoreStrUpdater * updater_temp = str_updater;
+  void* updater_handle_temp = updater_handle;
+  std::function<void(const std::string&, const NDArray&, NDArray*)> updt
+  = [updater_temp, updater_handle_temp]
+    (const std::string& key, const NDArray& recv, NDArray* local) {
+    NDArray* recv_copy = new NDArray();
+    *recv_copy = recv;
+    NDArray* local_copy = new NDArray();
+    *local_copy = *local;
+    updater_temp(key.c_str(), recv_copy, local_copy, updater_handle_temp);
+  };
+  static_cast<KVStore*>(handle)->set_updater(updt);
   API_END();
 }
 
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 11d4b64..e05819b 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -36,6 +36,13 @@
 
 namespace mxnet {
 namespace kvstore {
+
+enum KeyType {
+  kUndefinedKey = -1,
+  kStringKey,
+  kIntKey
+};
+
 /**
  * \brief store data in local machine
  */
@@ -59,16 +66,13 @@ class KVStoreLocal : public KVStore {
 
   void Init(const std::vector<int>& keys,
             const std::vector<NDArray>& values) override {
-    for (size_t i = 0; i < keys.size(); ++i) {
-      CHECK(local_.find(keys[i]) == local_.end())
-          << "duplicate init of key " << keys[i];
-      local_[keys[i]] = values[i].Copy(pinned_ctx_);
-      comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
-    }
+    SetKeyType(kIntKey);
+    Init_(keys, values);
   }
 
   void Init(const std::vector<std::string>& str_keys,
             const std::vector<NDArray>& values) override {
+    SetKeyType(kStringKey);
     std::vector<int> keys(str_keys.size());
     for (size_t i = 0; i < str_keys.size(); ++i) {
       auto &str_key = str_keys[i];
@@ -76,18 +80,78 @@ class KVStoreLocal : public KVStore {
             << "duplicate init of key " << str_key;
       auto key = next_str_key_++;
       str_key_dict_[str_key] = key;
+      // record reverse mapping from int to string
+      reverse_str_key_dict_[key] = str_key;
       keys[i] = key;
     }
-    Init(keys, values);
+    Init_(keys, values);
   }
 
   void Push(const std::vector<int>& keys,
             const std::vector<NDArray>& values,
             int priority) override {
+    SetKeyType(kIntKey);
+    Push_(keys, values, priority);
+  }
+
+  void Pull(const std::vector<int>& keys,
+            const std::vector<NDArray*>& values,
+            int priority) override {
+    SetKeyType(kIntKey);
+    Pull_(keys, values, priority);
+  }
+
+  void PullRowSparse(const std::vector<int>& keys,
+                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                     int priority = 0) override {
+    SetKeyType(kIntKey);
+    PullRowSparse_(keys, val_rowids, priority);
+  }
+
+  void Push(const std::vector<std::string>& str_keys,
+            const std::vector<NDArray>& values,
+            int priority) override {
+    SetKeyType(kStringKey);
+    std::vector<int> keys(str_keys.size());
+    LookupKeys(str_keys, &keys);
+    Push_(keys, values, priority);
+  }
+
+  void Pull(const std::vector<std::string>& str_keys,
+            const std::vector<NDArray*>& values,
+            int priority) override {
+    SetKeyType(kStringKey);
+    std::vector<int> keys(str_keys.size());
+    LookupKeys(str_keys, &keys);
+    Pull_(keys, values, priority);
+  }
+
+  void PullRowSparse(const std::vector<std::string>& str_keys,
+                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                     const int priority = 0) override {
+    SetKeyType(kStringKey);
+    std::vector<int> keys(str_keys.size());
+    LookupKeys(str_keys, &keys);
+    PullRowSparse_(keys, val_rowids, priority);
+  }
+
+ private:
+  void Init_(const std::vector<int>& keys,
+             const std::vector<NDArray>& values) {
+    for (size_t i = 0; i < keys.size(); ++i) {
+      CHECK(local_.find(keys[i]) == local_.end())
+          << "duplicate init of key " << keys[i];
+      local_[keys[i]] = values[i].Copy(pinned_ctx_);
+      comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
+    }
+  }
+
+  void Push_(const std::vector<int>& keys,
+             const std::vector<NDArray>& values,
+             int priority) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray> > grouped_vals;
     GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
-
     for (size_t i = 0; i < uniq_keys.size(); ++i) {
       int key = uniq_keys[i];
       const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority);
@@ -99,7 +163,18 @@ class KVStoreLocal : public KVStore {
             local.ctx().dev_mask() == cpu::kDevMask) {
           local = local.Copy(merged.ctx());
         }
-        updater_(key, merged,  &local);
+        // call the updater with string keys
+        // if string keys are used and str_updater_ is available
+        // otherwise fallback to updater_ which uses int key interface
+        if (key_type_ == kStringKey && str_updater_ != nullptr) {
+          // TODO(haibin) CHECK(str_updater_ != nullptr) if use_str_key
+          // after all language bindings pick up the string interface changes
+          const std::string &str_key = reverse_str_key_dict_[key];
+          // TODO(haibin) avoid reverse key lookup if use_str_key
+          str_updater_(str_key, merged,  &local);
+        } else {
+          updater_(key, merged,  &local);
+        }
       } else {
         if (merged.storage_type() != local.storage_type()) {
           local = merged.Copy(local.ctx());
@@ -110,9 +185,9 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void Pull(const std::vector<int>& keys,
-            const std::vector<NDArray*>& values,
-            int priority) override {
+  void Pull_(const std::vector<int>& keys,
+             const std::vector<NDArray*>& values,
+             int priority) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray*> > grouped_vals;
     GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -125,9 +200,9 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void PullRowSparse(const std::vector<int>& keys,
-                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                     int priority = 0) override {
+  void PullRowSparse_(const std::vector<int>& keys,
+                      const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                      int priority = 0) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
     GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
@@ -149,31 +224,16 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void Push(const std::vector<std::string>& str_keys,
-            const std::vector<NDArray>& values,
-            int priority) override {
-    std::vector<int> keys(str_keys.size());
-    LookupKeys(str_keys, &keys);
-    Push(keys, values, priority);
-  }
-
-  void Pull(const std::vector<std::string>& str_keys,
-            const std::vector<NDArray*>& values,
-            int priority) override {
-    std::vector<int> keys(str_keys.size());
-    LookupKeys(str_keys, &keys);
-    Pull(keys, values, priority);
-  }
-
-  void PullRowSparse(const std::vector<std::string>& str_keys,
-                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                     const int priority = 0) override {
-    std::vector<int> keys(str_keys.size());
-    LookupKeys(str_keys, &keys);
-    PullRowSparse(keys, val_rowids, priority);
+ protected:
+  /**
+   * \brief set the key type of the kvstore if it has not been set already.
+   * If the key type is already defined, check if it matches the provided key type
+   */
+  void SetKeyType(const KeyType key_type) {
+    if (key_type_ == kUndefinedKey) key_type_ = key_type;
+    CHECK_EQ(key_type_, key_type) << "Mixed key types are not allowed";
   }
 
- protected:
   /**
    * \brief group values on keys for push
    */
@@ -309,10 +369,14 @@ class KVStoreLocal : public KVStore {
   std::unordered_map<int, NDArray> local_;
   /// key mapping for string -> integer
   std::unordered_map<std::string, int> str_key_dict_;
+  /// reverse key mapping for integer -> string
+  std::unordered_map<int, std::string> reverse_str_key_dict_;
   /// the next available integer for string->int key mapping
   int next_str_key_ = 0;
   /// whether printed warning due to mismatch stype in each key
   std::unordered_set<int> warnings_printed_;
+  /// whether int or string is used for keys
+  KeyType key_type_ = kUndefinedKey;
 };
 }  // namespace kvstore
 }  // namespace mxnet
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index b39ec89..081bc9c 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -37,7 +37,7 @@ data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for
 def test_kvstore(kv_type):
     print(kv_type)
     kv = mx.kv.create(kv_type)
-    kv.set_optimizer(mx.optimizer.create('test', lr))
+    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
     for k, s in zip(keys, shapes):
         kv.init(k, mx.nd.zeros(s))
 
@@ -63,7 +63,7 @@ test_kvstore('local_allreduce_device')
 def test_group_kvstore(kv_type):
     print(kv_type)
     kv = mx.kv.create(kv_type)
-    kv.set_optimizer(mx.optimizer.create('test', lr))
+    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
     kv.init(keys, [mx.nd.zeros(s) for s in shapes])
     res = [np.zeros(s) for s in shapes]
     out = [[mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)] for s in shapes]
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index a43b98a..20ad2cd 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -19,11 +19,19 @@
 import mxnet as mx
 import numpy as np
 from mxnet.test_utils import rand_ndarray, assert_almost_equal
+from mxnet.base import py_str
 
 shape = (4, 4)
 keys = [5, 7, 11]
 str_keys = ['b', 'c', 'd']
 
+def assert_exception(f, *args, **kwargs):
+    try:
+        f(*args, **kwargs)
+        assert(False)
+    except:
+        return
+
 def init_kv(stype='default'):
     """init kv """
     kv = mx.kv.create()
@@ -180,9 +188,16 @@ def test_sparse_aggregator():
         assert_almost_equal(result_sum, expected_sum * num_devs)
 
 def updater(key, recv, local):
-    """use updater: +="""
+    """use updater: += with int keys"""
+    assert(isinstance(key, int))
     local += recv
 
+def str_updater(key, recv, local):
+    """use updater: += with str keys"""
+    if isinstance(key, bytes):
+        key = py_str(key)
+    assert(isinstance(key, str))
+    local += recv
 
 def test_updater(dev = 'cpu'):
     """updater"""
@@ -219,7 +234,7 @@ def test_updater(dev = 'cpu'):
     check_updater(kv, 3, keys)
 
     str_kv = init_kv_with_str()
-    str_kv._set_updater(updater)
+    str_kv._set_updater(str_updater)
     check_updater(str_kv, 'a', str_keys)
 
 def test_get_type():
@@ -228,48 +243,63 @@ def test_get_type():
     assert kv.type == kvtype
 
 def test_invalid_pull():
-    def check_invalid_single_kv_pair(kv, key):
-        dns_val = mx.nd.ones(shape) * 2
+    def check_ignored_pull_single(kv, key):
+        dns_val = (mx.nd.ones(shape) * 2)
         rsp_val = dns_val.tostype('row_sparse')
         kv.pull(key, out=rsp_val)
-        # pull should be ignored with no values updated
         check_diff_to_scalar(rsp_val, 2)
-        try:
-            # row_sparse_pull should be aborted when vals.stype != row_sparse
-            kv.row_sparse_pull(key, out=dns_val, rowids=mx.nd.array([1]))
-            assert(False)
-        except:
-            pass
-
-    def check_invalid_list_kv_pair(kv, key):
+
+    def check_ignored_pull_list(kv, key):
         dns_val = [mx.nd.ones(shape) * 2] * len(key)
         rsp_val = [val.tostype('row_sparse') for val in dns_val]
         kv.pull(key, out=rsp_val)
         for v in rsp_val:
-            # pull should be ignored with no values updated
             check_diff_to_scalar(v, 2)
-        try:
-            # row_sparse_pull should be aborted when vals.stype != row_sparse
-            kv.row_sparse_pull(key, out=dns_val, rowids=[mx.nd.array([1])] * len(key))
-            assert(False)
-        except:
-            pass
+
+    def check_invalid_rsp_pull_single(kv, key):
+        dns_val = mx.nd.ones(shape) * 2
+        assert_exception(kv.row_sparse_pull, key, out=dns_val, row_ids=mx.nd.array([1]))
+
+    def check_invalid_rsp_pull_list(kv, key):
+        dns_val = [mx.nd.ones(shape) * 2] * len(key)
+        assert_exception(kv.row_sparse_pull, key, out=dns_val,
+                         row_ids=[mx.nd.array([1])] * len(key))
+
+    def check_invalid_key_types_single(kv, key):
+        dns_val = mx.nd.ones(shape) * 2
+        rsp_val = dns_val.tostype('row_sparse')
+        assert_exception(kv.init, key, dns_val)
+        assert_exception(kv.push, key, dns_val)
+        assert_exception(kv.pull, key, dns_val)
+        assert_exception(kv.row_sparse_pull, key, rsp_val,
+                         row_ids=mx.nd.array([1]))
+
+    def check_invalid_key_types_list(kv, key):
+        dns_val = [mx.nd.ones(shape) * 2] * len(key)
+        rsp_val = [val.tostype('row_sparse') for val in dns_val]
+        assert_exception(kv.init, key, dns_val)
+        assert_exception(kv.push, key, dns_val)
+        assert_exception(kv.pull, key, dns_val)
+        assert_exception(kv.row_sparse_pull, key, rsp_val,
+                         row_ids=[mx.nd.array([1])] * len(key))
 
     int_kv = init_kv()
     str_kv = init_kv_with_str()
 
-    check_invalid_single_kv_pair(int_kv, 3)
-    check_invalid_single_kv_pair(str_kv, 'a')
-
-    check_invalid_list_kv_pair(int_kv, keys)
-    check_invalid_list_kv_pair(str_kv, str_keys)
+    kvs = [int_kv, str_kv]
+    single_keys = [3, 'a']
+    list_keys = [keys, str_keys]
+    for i in range(2):
+        # pull with rsp outputs should be ignored with no values updated
+        check_ignored_pull_single(kvs[i], single_keys[i])
+        check_ignored_pull_list(kvs[i], list_keys[i])
+        # row_sparse_pull should be aborted when vals.stype != row_sparse
+        check_invalid_rsp_pull_single(kvs[i], single_keys[i])
+        check_invalid_rsp_pull_list(kvs[i], list_keys[i])
+        # kvstore should be restricted to only accept either int or str keys
+        check_invalid_key_types_single(kvs[i], single_keys[1 - i])
+        check_invalid_key_types_list(kvs[i], list_keys[1 - i])
 
 if __name__ == '__main__':
-    test_init()
-    test_get_type()
-    test_single_kv_pair()
-    test_list_kv_pair()
-    test_sparse_aggregator()
-    test_aggregator()
-    test_updater()
-    test_row_sparse_pull()
+    import nose
+    nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].