You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/30 18:02:03 UTC
[incubator-mxnet] branch master updated: Add string interface to
updater to make it consistent with kvstore (#7585)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 470b564 Add string interface to updater to make it consistent with kvstore (#7585)
470b564 is described below
commit 470b56437290b33fef51b067c96fdce08cefa584
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed Aug 30 11:02:00 2017 -0700
Add string interface to updater to make it consistent with kvstore (#7585)
* str kv updater draft
* backward compatibility for other languages
* add capi MXKVStoreSetUpdaterEx
* fix nightly test_kvstore test
* convert c_char_p/byte to str for python3
* add key type restriction to backend
* add test to check mixed key types
* remove nested catch/throw
---
include/mxnet/c_api.h | 48 +++++++++++-
include/mxnet/kvstore.h | 24 +++++-
python/mxnet/kvstore.py | 100 ++++++++++++++++--------
python/mxnet/optimizer.py | 4 +
src/c_api/c_api.cc | 57 ++++++++++++--
src/kvstore/kvstore_local.h | 140 +++++++++++++++++++++++++---------
tests/nightly/test_kvstore.py | 4 +-
tests/python/unittest/test_kvstore.py | 96 +++++++++++++++--------
8 files changed, 359 insertions(+), 114 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bba6190..ef9d31e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1602,7 +1602,7 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
int priority);
/*!
- * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
* The NDArray pulled back will be in row_sparse storage with only the specified
* row_ids present based row_ids (others rows are zeros).
* \param handle handle to the kvstore
@@ -1615,10 +1615,28 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
*/
MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
mx_uint num,
- const char** keys,
+ const int* keys,
NDArrayHandle* vals,
const NDArrayHandle* row_ids,
int priority);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
+ * The NDArray pulled back will be in row_sparse storage with only the specified
+ * row_ids present (other rows are zeros).
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param row_ids the list of row_id NDArrays
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
+ mx_uint num,
+ const char** keys,
+ NDArrayHandle* vals,
+ const NDArrayHandle* row_ids,
+ int priority);
/*!
* \brief user-defined updater for the kvstore
@@ -1633,7 +1651,19 @@ typedef void (MXKVStoreUpdater)(int key,
NDArrayHandle local,
void *handle);
/*!
- * \brief register an push updater
+ * \brief user-defined updater for the kvstore with string keys
+ * It's this updater's responsibility to delete \a recv and \a local
+ * \param key the key
+ * \param recv the pushed value on this key
+ * \param local the value stored on local on this key
+ * \param handle The additional handle to the updater
+ */
+typedef void (MXKVStoreStrUpdater)(const char* key,
+ NDArrayHandle recv,
+ NDArrayHandle local,
+ void *handle);
+/*!
+ * \brief register a push updater
* \param handle handle to the KVStore
* \param updater udpater function
* \param updater_handle The additional handle used to invoke the updater
@@ -1643,6 +1673,18 @@ MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle,
MXKVStoreUpdater updater,
void *updater_handle);
/*!
+ * \brief register a push updater with int keys and one with string keys
+ * \param handle handle to the KVStore
+ * \param updater updater function with int keys
+ * \param str_updater updater function with string keys
+ * \param updater_handle The additional handle used to invoke the updater
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
+ MXKVStoreUpdater updater,
+ MXKVStoreStrUpdater str_updater,
+ void *updater_handle);
+/*!
* \brief get the type of the kvstore
* \param handle handle to the KVStore
* \param type a string type
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 9ea63b4..bca88a5 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -202,6 +202,10 @@ class KVStore {
* \brief the prototype of user-defined updater
*/
typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
+ /**
+ * \brief the prototype of user-defined updater with string keys
+ */
+ typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
/*!
* \brief set an updater
*
@@ -215,6 +219,19 @@ class KVStore {
CHECK(updater) << "invalid updater";
updater_ = updater;
}
+ /*!
+ * \brief set an updater with string keys
+ *
+ * Given a string key, assume \a x is the received (pushed) value and \a y is the
+ * value stored on the store node. The store updates \a y by `h(x, &y)`. The
+ * default \a h is ASSIGN, namely `*y = x`.
+ *
+ * \param updater user-defined string updater, default is assign
+ */
+ virtual void set_updater(const StrUpdater& updater) {
+ CHECK(updater) << "invalid updater";
+ str_updater_ = updater;
+ }
/******************************************************
* the following are used for multi-machines.
@@ -356,11 +373,16 @@ class KVStore {
protected:
/**
- * \brief the user-defined updater
+ * \brief the user-defined updater
*/
Updater updater_;
/**
+ * \brief the user-defined updater with string keys
+ */
+ StrUpdater str_updater_;
+
+ /**
* \brief the kvstore type
*/
std::string type_;
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 2af70e3..bc034c5 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -29,26 +29,39 @@ from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt
def _ctype_key_value(keys, vals):
+ """
+ Returns ctype arrays for the key-value args, and whether string keys are used.
+ For internal use only.
+ """
if isinstance(keys, (tuple, list)):
assert(len(keys) == len(vals))
c_keys = []
c_vals = []
+ use_str_keys = None
for key, val in zip(keys, vals):
- c_key_i, c_val_i = _ctype_key_value(key, val)
+ c_key_i, c_val_i, str_keys_i = _ctype_key_value(key, val)
c_keys += c_key_i
c_vals += c_val_i
- return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))
- names = []
- keys = str(keys)
+ use_str_keys = str_keys_i if use_str_keys is None else use_str_keys
+ assert(use_str_keys == str_keys_i), "inconsistent types of keys detected."
+ c_keys_arr = c_array(ctypes.c_char_p, c_keys) if use_str_keys \
+ else c_array(ctypes.c_int, c_keys)
+ c_vals_arr = c_array(NDArrayHandle, c_vals)
+ return (c_keys_arr, c_vals_arr, use_str_keys)
+
+ assert(isinstance(keys, (int,) + string_types)), \
+ "unexpected type for keys: " + str(type(keys))
+ use_str_keys = isinstance(keys, string_types)
if isinstance(vals, NDArray):
- names.append(c_str(keys))
- return (c_array(ctypes.c_char_p, names),
- c_array(NDArrayHandle, [vals.handle]))
+ c_keys = c_array(ctypes.c_char_p, [c_str(keys)]) if use_str_keys \
+ else c_array(ctypes.c_int, [keys])
+ return (c_keys, c_array(NDArrayHandle, [vals.handle]), use_str_keys)
else:
for value in vals:
assert(isinstance(value, NDArray))
- return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
- c_array(NDArrayHandle, [value.handle for value in vals]))
+ c_keys = c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)) if use_str_keys \
+ else c_array(ctypes.c_int, [keys] * len(vals))
+ return (c_keys, c_array(NDArrayHandle, [value.handle for value in vals]), use_str_keys)
def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
@@ -74,6 +87,7 @@ class KVStore(object):
self.handle = handle
self._updater = None
self._updater_func = None
+ self._str_updater_func = None
def __del__(self):
check_call(_LIB.MXKVStoreFree(self.handle))
@@ -88,7 +102,7 @@ class KVStore(object):
Parameters
----------
- key : str or sequence of str
+ key : str, int, or sequence of str or int
The keys.
value : NDArray or sequence of NDArray
Values corresponding to the keys.
@@ -106,11 +120,14 @@ class KVStore(object):
[ 2. 2. 2.]]
>>> # init a list of key-value pairs
- >>> keys = ['5', '7', '9']
+ >>> keys = [5, 7, 9]
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
"""
- ckeys, cvals = _ctype_key_value(key, value)
- check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
+ if use_str_keys:
+ check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
+ else:
+ check_call(_LIB.MXKVStoreInit(self.handle, mx_uint(len(ckeys)), ckeys, cvals))
def push(self, key, value, priority=0):
""" Pushes a single or a sequence of key-value pairs into the store.
@@ -123,7 +140,7 @@ class KVStore(object):
Parameters
----------
- key : str or list of str
+ key : str, int, or sequence of str or int
Keys.
value : NDArray or list of NDArray or list of list of NDArray
@@ -154,6 +171,7 @@ class KVStore(object):
>>> # push a list of keys.
>>> # single device
+ >>> keys = [4, 5, 6]
>>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.pull(keys, out=b)
@@ -162,6 +180,7 @@ class KVStore(object):
[ 1. 1. 1.]]
>>> # multiple devices:
+ >>> keys = ['7', '8', '9']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
>>> kv.push(keys, b)
>>> kv.pull(keys, out=b)
@@ -169,10 +188,13 @@ class KVStore(object):
[[ 4. 4. 4.]
[ 4. 4. 4.]]
"""
- ckeys, cvals = _ctype_key_value(key, value)
- check_call(_LIB.MXKVStorePushEx(
- self.handle, mx_uint(len(ckeys)), ckeys, cvals,
- ctypes.c_int(priority)))
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
+ if use_str_keys:
+ check_call(_LIB.MXKVStorePushEx(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+ else:
+ check_call(_LIB.MXKVStorePush(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
def pull(self, key, out=None, priority=0):
@@ -191,7 +213,7 @@ class KVStore(object):
Parameters
----------
- key : int or list of int
+ key : str, int, or sequence of str or int
Keys.
out: NDArray or list of NDArray or list of list of NDArray
@@ -220,13 +242,14 @@ class KVStore(object):
>>> # pull a list of key-value pairs.
>>> # On single device
- >>> keys = ['5', '7', '9']
+ >>> keys = [5, 7, 9]
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.pull(keys, out=b)
>>> print b[1].asnumpy()
[[ 2. 2. 2.]
[ 2. 2. 2.]]
>>> # On multiple devices
+ >>> keys = ['6', '8', '10']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
>>> kv.pull(keys, out=b)
>>> print b[1][1].asnumpy()
@@ -234,10 +257,13 @@ class KVStore(object):
[ 2. 2. 2.]]
"""
assert(out is not None)
- ckeys, cvals = _ctype_key_value(key, out)
- check_call(_LIB.MXKVStorePullEx(
- self.handle, mx_uint(len(ckeys)), ckeys, cvals,
- ctypes.c_int(priority)))
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+ if use_str_keys:
+ check_call(_LIB.MXKVStorePullEx(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+ else:
+ check_call(_LIB.MXKVStorePull(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single row_sparse value or a sequence of row_sparse values from the store
@@ -250,7 +276,7 @@ class KVStore(object):
Parameters
----------
- key : str or list of str
+ key : str, int, or sequence of str or int
Keys.
out: NDArray or list of NDArray or list of list of NDArray
@@ -291,12 +317,16 @@ class KVStore(object):
"""
assert(out is not None)
assert(row_ids is not None)
- ckeys, cvals = _ctype_key_value(key, out)
- _, crow_ids = _ctype_key_value(key, row_ids)
- assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
-
- check_call(_LIB.MXKVStorePullRowSparse(
- self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+ _, crow_ids, _ = _ctype_key_value(key, row_ids)
+ assert(len(crow_ids) == len(cvals)), \
+ "the number of row_ids doesn't match the number of values"
+ if use_str_keys:
+ check_call(_LIB.MXKVStorePullRowSparseEx(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+ else:
+ check_call(_LIB.MXKVStorePullRowSparse(
+ self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
def set_optimizer(self, optimizer):
@@ -436,10 +466,16 @@ class KVStore(object):
[ 6. 6. 6.]]
"""
self._updater = updater
+ # set updater with int keys
_updater_proto = ctypes.CFUNCTYPE(
None, ctypes.c_int, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
self._updater_func = _updater_proto(_updater_wrapper(updater))
- check_call(_LIB.MXKVStoreSetUpdater(self.handle, self._updater_func, None))
+ # set updater with str keys
+ _str_updater_proto = ctypes.CFUNCTYPE(
+ None, ctypes.c_char_p, NDArrayHandle, NDArrayHandle, ctypes.c_void_p)
+ self._str_updater_func = _str_updater_proto(_updater_wrapper(updater))
+ check_call(_LIB.MXKVStoreSetUpdaterEx(self.handle, self._updater_func,
+ self._str_updater_func, None))
def _barrier(self):
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index e7e283f..099d2b7 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -21,6 +21,7 @@ import pickle
import logging
import warnings
import numpy
+from .base import py_str
from .ndarray import (NDArray, zeros, clip, sqrt, sign, array, maximum, abs as NDabs)
from .ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update)
@@ -949,6 +950,9 @@ class Updater(object):
def __call__(self, index, grad, weight):
"""Updates weight given gradient and index."""
+ # convert ctypes.c_char_p.value (bytes under Python 3) back to python str if needed
+ if isinstance(index, bytes):
+ index = py_str(index)
if index not in self.states:
self.states[index] = self.optimizer.create_state(index, weight)
self.states_synced[index] = True
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 088e208..5171e27 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -829,12 +829,12 @@ int MXKVStorePullEx(KVStoreHandle handle,
int MXKVStorePullRowSparse(KVStoreHandle handle,
mx_uint num,
- const char** keys,
+ const int* keys,
NDArrayHandle* vals,
const NDArrayHandle* row_ids,
int priority) {
API_BEGIN();
- std::vector<std::string> v_keys(num);
+ std::vector<int> v_keys(num);
std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
@@ -845,10 +845,27 @@ int MXKVStorePullRowSparse(KVStoreHandle handle,
API_END();
}
-int MXKVStoreSetUpdater(KVStoreHandle handle,
- MXKVStoreUpdater updater,
- void* updater_handle) {
+int MXKVStorePullRowSparseEx(KVStoreHandle handle,
+ mx_uint num,
+ const char** keys,
+ NDArrayHandle* vals,
+ const NDArrayHandle* row_ids,
+ int priority) {
API_BEGIN();
+ std::vector<std::string> v_keys(num);
+ std::vector<std::pair<NDArray*, NDArray>> v_val_rowids(num);
+ for (mx_uint i = 0; i < num; ++i) {
+ v_keys[i] = keys[i];
+ v_val_rowids[i] = std::make_pair(static_cast<NDArray*>(vals[i]),
+ *static_cast<NDArray*>(row_ids[i]));
+ }
+ static_cast<KVStore*>(handle)->PullRowSparse(v_keys, v_val_rowids, priority);
+ API_END();
+}
+
+void MXKVStoreSetUpdaterImpl(KVStoreHandle handle,
+ MXKVStoreUpdater updater,
+ void* updater_handle) {
MXKVStoreUpdater * updater_temp = updater;
void* updater_handle_temp = updater_handle;
std::function<void(int, const NDArray&, NDArray*)> updt
@@ -860,6 +877,36 @@ int MXKVStoreSetUpdater(KVStoreHandle handle,
updater_temp(key, recv_copy, local_copy, updater_handle_temp);
};
static_cast<KVStore*>(handle)->set_updater(updt);
+}
+
+int MXKVStoreSetUpdater(KVStoreHandle handle,
+ MXKVStoreUpdater updater,
+ void* updater_handle) {
+ API_BEGIN();
+ MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
+ API_END();
+}
+
+int MXKVStoreSetUpdaterEx(KVStoreHandle handle,
+ MXKVStoreUpdater updater,
+ MXKVStoreStrUpdater str_updater,
+ void* updater_handle) {
+ API_BEGIN();
+ // set updater with int keys
+ MXKVStoreSetUpdaterImpl(handle, updater, updater_handle);
+ // set updater with string keys
+ MXKVStoreStrUpdater * updater_temp = str_updater;
+ void* updater_handle_temp = updater_handle;
+ std::function<void(const std::string&, const NDArray&, NDArray*)> updt
+ = [updater_temp, updater_handle_temp]
+ (const std::string& key, const NDArray& recv, NDArray* local) {
+ NDArray* recv_copy = new NDArray();
+ *recv_copy = recv;
+ NDArray* local_copy = new NDArray();
+ *local_copy = *local;
+ updater_temp(key.c_str(), recv_copy, local_copy, updater_handle_temp);
+ };
+ static_cast<KVStore*>(handle)->set_updater(updt);
API_END();
}
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 11d4b64..e05819b 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -36,6 +36,13 @@
namespace mxnet {
namespace kvstore {
+
+enum KeyType {
+ kUndefinedKey = -1,
+ kStringKey,
+ kIntKey
+};
+
/**
* \brief store data in local machine
*/
@@ -59,16 +66,13 @@ class KVStoreLocal : public KVStore {
void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
- for (size_t i = 0; i < keys.size(); ++i) {
- CHECK(local_.find(keys[i]) == local_.end())
- << "duplicate init of key " << keys[i];
- local_[keys[i]] = values[i].Copy(pinned_ctx_);
- comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
- }
+ SetKeyType(kIntKey);
+ Init_(keys, values);
}
void Init(const std::vector<std::string>& str_keys,
const std::vector<NDArray>& values) override {
+ SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
for (size_t i = 0; i < str_keys.size(); ++i) {
auto &str_key = str_keys[i];
@@ -76,18 +80,78 @@ class KVStoreLocal : public KVStore {
<< "duplicate init of key " << str_key;
auto key = next_str_key_++;
str_key_dict_[str_key] = key;
+ // record reverse mapping from int to string
+ reverse_str_key_dict_[key] = str_key;
keys[i] = key;
}
- Init(keys, values);
+ Init_(keys, values);
}
void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
+ SetKeyType(kIntKey);
+ Push_(keys, values, priority);
+ }
+
+ void Pull(const std::vector<int>& keys,
+ const std::vector<NDArray*>& values,
+ int priority) override {
+ SetKeyType(kIntKey);
+ Pull_(keys, values, priority);
+ }
+
+ void PullRowSparse(const std::vector<int>& keys,
+ const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+ int priority = 0) override {
+ SetKeyType(kIntKey);
+ PullRowSparse_(keys, val_rowids, priority);
+ }
+
+ void Push(const std::vector<std::string>& str_keys,
+ const std::vector<NDArray>& values,
+ int priority) override {
+ SetKeyType(kStringKey);
+ std::vector<int> keys(str_keys.size());
+ LookupKeys(str_keys, &keys);
+ Push_(keys, values, priority);
+ }
+
+ void Pull(const std::vector<std::string>& str_keys,
+ const std::vector<NDArray*>& values,
+ int priority) override {
+ SetKeyType(kStringKey);
+ std::vector<int> keys(str_keys.size());
+ LookupKeys(str_keys, &keys);
+ Pull_(keys, values, priority);
+ }
+
+ void PullRowSparse(const std::vector<std::string>& str_keys,
+ const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+ const int priority = 0) override {
+ SetKeyType(kStringKey);
+ std::vector<int> keys(str_keys.size());
+ LookupKeys(str_keys, &keys);
+ PullRowSparse_(keys, val_rowids, priority);
+ }
+
+ private:
+ void Init_(const std::vector<int>& keys,
+ const std::vector<NDArray>& values) {
+ for (size_t i = 0; i < keys.size(); ++i) {
+ CHECK(local_.find(keys[i]) == local_.end())
+ << "duplicate init of key " << keys[i];
+ local_[keys[i]] = values[i].Copy(pinned_ctx_);
+ comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
+ }
+ }
+
+ void Push_(const std::vector<int>& keys,
+ const std::vector<NDArray>& values,
+ int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
-
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority);
@@ -99,7 +163,18 @@ class KVStoreLocal : public KVStore {
local.ctx().dev_mask() == cpu::kDevMask) {
local = local.Copy(merged.ctx());
}
- updater_(key, merged, &local);
+ // call the updater with string keys
+ // if string keys are used and str_updater_ is available
+ // otherwise fallback to updater_ which uses int key interface
+ if (key_type_ == kStringKey && str_updater_ != nullptr) {
+ // TODO(haibin) CHECK(str_updater_ != nullptr) if use_str_key
+ // after all language bindings pick up string interface changes
+ const std::string &str_key = reverse_str_key_dict_[key];
+ // TODO(haibin) avoid reverse key lookup if use_str_key
+ str_updater_(str_key, merged, &local);
+ } else {
+ updater_(key, merged, &local);
+ }
} else {
if (merged.storage_type() != local.storage_type()) {
local = merged.Copy(local.ctx());
@@ -110,9 +185,9 @@ class KVStoreLocal : public KVStore {
}
}
- void Pull(const std::vector<int>& keys,
- const std::vector<NDArray*>& values,
- int priority) override {
+ void Pull_(const std::vector<int>& keys,
+ const std::vector<NDArray*>& values,
+ int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -125,9 +200,9 @@ class KVStoreLocal : public KVStore {
}
}
- void PullRowSparse(const std::vector<int>& keys,
- const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
- int priority = 0) override {
+ void PullRowSparse_(const std::vector<int>& keys,
+ const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+ int priority = 0) {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
@@ -149,31 +224,16 @@ class KVStoreLocal : public KVStore {
}
}
- void Push(const std::vector<std::string>& str_keys,
- const std::vector<NDArray>& values,
- int priority) override {
- std::vector<int> keys(str_keys.size());
- LookupKeys(str_keys, &keys);
- Push(keys, values, priority);
- }
-
- void Pull(const std::vector<std::string>& str_keys,
- const std::vector<NDArray*>& values,
- int priority) override {
- std::vector<int> keys(str_keys.size());
- LookupKeys(str_keys, &keys);
- Pull(keys, values, priority);
- }
-
- void PullRowSparse(const std::vector<std::string>& str_keys,
- const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
- const int priority = 0) override {
- std::vector<int> keys(str_keys.size());
- LookupKeys(str_keys, &keys);
- PullRowSparse(keys, val_rowids, priority);
+ protected:
+ /**
+ * \brief set the key type of the kvstore if it hasn't been set already.
+ * If the key type is already defined, check that it matches the provided key type.
+ */
+ void SetKeyType(const KeyType key_type) {
+ if (key_type_ == kUndefinedKey) key_type_ = key_type;
+ CHECK_EQ(key_type_, key_type) << "Mixed key types are not allowed";
}
- protected:
/**
* \brief group values on keys for push
*/
@@ -309,10 +369,14 @@ class KVStoreLocal : public KVStore {
std::unordered_map<int, NDArray> local_;
/// key mapping for string -> integer
std::unordered_map<std::string, int> str_key_dict_;
+ /// reverse key mapping for integer -> string
+ std::unordered_map<int, std::string> reverse_str_key_dict_;
/// the next available integer for string->int key mapping
int next_str_key_ = 0;
/// whether printed warning due to mismatch stype in each key
std::unordered_set<int> warnings_printed_;
+ /// whether int or string is used for keys
+ KeyType key_type_ = kUndefinedKey;
};
} // namespace kvstore
} // namespace mxnet
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index b39ec89..081bc9c 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -37,7 +37,7 @@ data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for
def test_kvstore(kv_type):
print(kv_type)
kv = mx.kv.create(kv_type)
- kv.set_optimizer(mx.optimizer.create('test', lr))
+ kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
for k, s in zip(keys, shapes):
kv.init(k, mx.nd.zeros(s))
@@ -63,7 +63,7 @@ test_kvstore('local_allreduce_device')
def test_group_kvstore(kv_type):
print(kv_type)
kv = mx.kv.create(kv_type)
- kv.set_optimizer(mx.optimizer.create('test', lr))
+ kv.set_optimizer(mx.optimizer.create('test', rescale_grad=lr))
kv.init(keys, [mx.nd.zeros(s) for s in shapes])
res = [np.zeros(s) for s in shapes]
out = [[mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)] for s in shapes]
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index a43b98a..20ad2cd 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -19,11 +19,19 @@
import mxnet as mx
import numpy as np
from mxnet.test_utils import rand_ndarray, assert_almost_equal
+from mxnet.base import py_str
shape = (4, 4)
keys = [5, 7, 11]
str_keys = ['b', 'c', 'd']
+def assert_exception(f, *args, **kwargs):
+ try:
+ f(*args, **kwargs)
+ assert(False)
+ except:
+ return
+
def init_kv(stype='default'):
"""init kv """
kv = mx.kv.create()
@@ -180,9 +188,16 @@ def test_sparse_aggregator():
assert_almost_equal(result_sum, expected_sum * num_devs)
def updater(key, recv, local):
- """use updater: +="""
+ """use updater: += with int keys"""
+ assert(isinstance(key, int))
local += recv
+def str_updater(key, recv, local):
+ """use updater: += with str keys"""
+ if isinstance(key, bytes):
+ key = py_str(key)
+ assert(isinstance(key, str))
+ local += recv
def test_updater(dev = 'cpu'):
"""updater"""
@@ -219,7 +234,7 @@ def test_updater(dev = 'cpu'):
check_updater(kv, 3, keys)
str_kv = init_kv_with_str()
- str_kv._set_updater(updater)
+ str_kv._set_updater(str_updater)
check_updater(str_kv, 'a', str_keys)
def test_get_type():
@@ -228,48 +243,63 @@ def test_get_type():
assert kv.type == kvtype
def test_invalid_pull():
- def check_invalid_single_kv_pair(kv, key):
- dns_val = mx.nd.ones(shape) * 2
+ def check_ignored_pull_single(kv, key):
+ dns_val = (mx.nd.ones(shape) * 2)
rsp_val = dns_val.tostype('row_sparse')
kv.pull(key, out=rsp_val)
- # pull should be ignored with no values updated
check_diff_to_scalar(rsp_val, 2)
- try:
- # row_sparse_pull should be aborted when vals.stype != row_sparse
- kv.row_sparse_pull(key, out=dns_val, rowids=mx.nd.array([1]))
- assert(False)
- except:
- pass
-
- def check_invalid_list_kv_pair(kv, key):
+
+ def check_ignored_pull_list(kv, key):
dns_val = [mx.nd.ones(shape) * 2] * len(key)
rsp_val = [val.tostype('row_sparse') for val in dns_val]
kv.pull(key, out=rsp_val)
for v in rsp_val:
- # pull should be ignored with no values updated
check_diff_to_scalar(v, 2)
- try:
- # row_sparse_pull should be aborted when vals.stype != row_sparse
- kv.row_sparse_pull(key, out=dns_val, rowids=[mx.nd.array([1])] * len(key))
- assert(False)
- except:
- pass
+
+ def check_invalid_rsp_pull_single(kv, key):
+ dns_val = mx.nd.ones(shape) * 2
+ assert_exception(kv.row_sparse_pull, key, out=dns_val, row_ids=mx.nd.array([1]))
+
+ def check_invalid_rsp_pull_list(kv, key):
+ dns_val = [mx.nd.ones(shape) * 2] * len(key)
+ assert_exception(kv.row_sparse_pull, key, out=dns_val,
+ row_ids=[mx.nd.array([1])] * len(key))
+
+ def check_invalid_key_types_single(kv, key):
+ dns_val = mx.nd.ones(shape) * 2
+ rsp_val = dns_val.tostype('row_sparse')
+ assert_exception(kv.init, key, dns_val)
+ assert_exception(kv.push, key, dns_val)
+ assert_exception(kv.pull, key, dns_val)
+ assert_exception(kv.row_sparse_pull, key, rsp_val,
+ row_ids=mx.nd.array([1]))
+
+ def check_invalid_key_types_list(kv, key):
+ dns_val = [mx.nd.ones(shape) * 2] * len(key)
+ rsp_val = [val.tostype('row_sparse') for val in dns_val]
+ assert_exception(kv.init, key, dns_val)
+ assert_exception(kv.push, key, dns_val)
+ assert_exception(kv.pull, key, dns_val)
+ assert_exception(kv.row_sparse_pull, key, rsp_val,
+ row_ids=[mx.nd.array([1])] * len(key))
int_kv = init_kv()
str_kv = init_kv_with_str()
- check_invalid_single_kv_pair(int_kv, 3)
- check_invalid_single_kv_pair(str_kv, 'a')
-
- check_invalid_list_kv_pair(int_kv, keys)
- check_invalid_list_kv_pair(str_kv, str_keys)
+ kvs = [int_kv, str_kv]
+ single_keys = [3, 'a']
+ list_keys = [keys, str_keys]
+ for i in range(2):
+ # pull with rsp outputs should be ignored with no values updated
+ check_ignored_pull_single(kvs[i], single_keys[i])
+ check_ignored_pull_list(kvs[i], list_keys[i])
+ # row_sparse_pull should be aborted when vals.stype != row_sparse
+ check_invalid_rsp_pull_single(kvs[i], single_keys[i])
+ check_invalid_rsp_pull_list(kvs[i], list_keys[i])
+ # kvstore should be restricted to only accept either int or str keys
+ check_invalid_key_types_single(kvs[i], single_keys[1 - i])
+ check_invalid_key_types_list(kvs[i], list_keys[1 - i])
if __name__ == '__main__':
- test_init()
- test_get_type()
- test_single_kv_pair()
- test_list_kv_pair()
- test_sparse_aggregator()
- test_aggregator()
- test_updater()
- test_row_sparse_pull()
+ import nose
+ nose.runmodule()
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].