You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:18 UTC
[05/50] [abbrv] incubator-mxnet-test git commit: [Scala] support str key type in kvstore (#6829)
[Scala] support str key type in kvstore (#6829)
Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/5acc2c90
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/5acc2c90
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/5acc2c90
Branch: refs/heads/master
Commit: 5acc2c90d026015ad02b66839c19f72bdce35dfb
Parents: 9e66154
Author: 梁德澎 <li...@gmail.com>
Authored: Tue Jun 27 22:50:29 2017 +0800
Committer: Yizhi Liu <ja...@gmail.com>
Committed: Tue Jun 27 22:50:29 2017 +0800
----------------------------------------------------------------------
.../src/main/scala/ml/dmlc/mxnet/KVStore.scala | 30 ++++----
.../src/main/scala/ml/dmlc/mxnet/LibInfo.scala | 14 ++++
.../src/main/scala/ml/dmlc/mxnet/Model.scala | 22 ++++--
.../scala/ml/dmlc/mxnet/module/Module.scala | 4 +-
.../test/scala/ml/dmlc/mxnet/KVStoreSuite.scala | 40 +++++++---
.../main/native/ml_dmlc_mxnet_native_c_api.cc | 77 +++++++++++++++++++-
6 files changed, 152 insertions(+), 35 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
----------------------------------------------------------------------
NOTE(review): this section switches KVStore.init/push/pull from Int keys to String keys
and routes them to the new native entry points mxKVStoreInitEx / mxKVStorePushEx /
mxKVStorePullEx. No Int-keyed overload remains in the hunks below, so existing callers
that pass Int keys stop compiling and must migrate to String keys (source-incompatible
API change). Priorities and the length checks (len(keys) != len(values/outs)) are unchanged.
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
index 32e0ace..94dd254 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
@@ -83,13 +83,13 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
* @param keys The keys.
* @param values The values.
*/
- def init(keys: Array[Int], values: Array[NDArray]): Unit = {
+ def init(keys: Array[String], values: Array[NDArray]): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle)
- checkCall(_LIB.mxKVStoreInit(handle, keys.length, keys, valuePtrs))
+ checkCall(_LIB.mxKVStoreInitEx(handle, keys.length, keys, valuePtrs))
}
- def init(key: Int, value: NDArray): Unit = {
+ def init(key: String, value: NDArray): Unit = {
init(Array(key), Array(value))
}
@@ -107,24 +107,24 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
* The higher the priority, the faster this action is likely
* to be executed before other push actions.
*/
- def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = {
+ def push(keys: Array[String], values: Array[NDArray], priority: Int): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle)
- checkCall(_LIB.mxKVStorePush(handle, keys.length, keys, valuePtrs, priority))
+ checkCall(_LIB.mxKVStorePushEx(handle, keys.length, keys, valuePtrs, priority))
}
- def push(keys: Array[Int], values: Array[NDArray]): Unit = push(keys, values, 0)
+ def push(keys: Array[String], values: Array[NDArray]): Unit = push(keys, values, 0)
- def push(key: Int, value: NDArray, priority: Int = 0): Unit = {
+ def push(key: String, value: NDArray, priority: Int = 0): Unit = {
push(Array(key), Array(value), priority)
}
- def push(key: Int, values: Array[NDArray], priority: Int): Unit = {
+ def push(key: String, values: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(values.length)(key)
push(keys, values, priority)
}
- def push(key: Int, values: Array[NDArray]): Unit = {
+ def push(key: String, values: Array[NDArray]): Unit = {
push(key, values, 0)
}
@@ -143,24 +143,24 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
* The higher the priority, the faster this action is likely
* to be executed before other push actions.
*/
- def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = {
+ def pull(keys: Array[String], outs: Array[NDArray], priority: Int): Unit = {
require(keys.length == outs.length, "len(keys) != len(outs)")
val outPtrs = outs.map(_.handle)
- checkCall(_LIB.mxKVStorePull(handle, keys.length, keys, outPtrs, priority))
+ checkCall(_LIB.mxKVStorePullEx(handle, keys.length, keys, outPtrs, priority))
}
- def pull(keys: Array[Int], outs: Array[NDArray]): Unit = pull(keys, outs, 0)
+ def pull(keys: Array[String], outs: Array[NDArray]): Unit = pull(keys, outs, 0)
- def pull(key: Int, out: NDArray, priority: Int = 0): Unit = {
+ def pull(key: String, out: NDArray, priority: Int = 0): Unit = {
pull(Array(key), Array(out), priority)
}
- def pull(key: Int, outs: Array[NDArray], priority: Int): Unit = {
+ def pull(key: String, outs: Array[NDArray], priority: Int): Unit = {
val keys = Array.fill(outs.length)(key)
pull(keys, outs, priority)
}
- def pull(key: Int, outs: Array[NDArray]): Unit = {
+ def pull(key: String, outs: Array[NDArray]): Unit = {
pull(key, outs, 0)
}
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
----------------------------------------------------------------------
NOTE(review): adds three @native declarations (mxKVStoreInitEx, mxKVStorePushEx,
mxKVStorePullEx) taking keys: Array[String]; each mirrors the existing Int-keyed
declaration directly above it. The Int-keyed declarations stay untouched (context lines),
so their native symbols remain bound even though KVStore.scala no longer calls them.
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 97ba815..a943e31 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -117,16 +117,30 @@ private[mxnet] class LibInfo {
len: MXUint,
keys: Array[Int],
values: Array[NDArrayHandle]): Int
+ @native def mxKVStoreInitEx(handle: KVStoreHandle,
+ len: MXUint,
+ keys: Array[String],
+ values: Array[NDArrayHandle]): Int
@native def mxKVStorePush(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
values: Array[NDArrayHandle],
priority: Int): Int
+ @native def mxKVStorePushEx(handle: KVStoreHandle,
+ len: MXUint,
+ keys: Array[String],
+ values: Array[NDArrayHandle],
+ priority: Int): Int
@native def mxKVStorePull(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
outs: Array[NDArrayHandle],
priority: Int): Int
+ @native def mxKVStorePullEx(handle: KVStoreHandle,
+ len: MXUint,
+ keys: Array[String],
+ outs: Array[NDArrayHandle],
+ priority: Int): Int
@native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater): Int
@native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
@native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
----------------------------------------------------------------------
NOTE(review): KVStore entries are now keyed by parameter name (paramNames(idx)) instead
of positional index; the index is still used only for priority (-idx / -index).
updateParamsOnKVStore and updateParams gain a paramNames argument (placed before the
defaulted kvStore in updateParams), and both call sites pass executorManager.paramNames.
In the init loop the already-bound `name` is reused for the argParams lookup rather
than re-indexing paramNames.
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
index 69fe682..81ff1cf 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -163,9 +163,10 @@ object Model {
require(paramArrays.length == paramNames.length)
for (idx <- 0 until paramArrays.length) {
val paramOnDevs = paramArrays(idx)
- kvStore.init(idx, argParams(paramNames(idx)))
+ val name = paramNames(idx)
+ kvStore.init(name, argParams(name))
if (updateOnKVStore) {
- kvStore.pull(idx, paramOnDevs, -idx)
+ kvStore.pull(name, paramOnDevs, -idx)
}
}
}
@@ -173,13 +174,15 @@ object Model {
// Perform update of param_arrays from grad_arrays on kvstore
private[mxnet] def updateParamsOnKVStore(paramArrays: IndexedSeq[Array[NDArray]],
gradArrays: IndexedSeq[Array[NDArray]],
- kvStore: Option[KVStore]): Unit = {
+ kvStore: Option[KVStore],
+ paramNames: IndexedSeq[String]): Unit = {
(paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) =>
if (gradList != null) {
+ val name = paramNames(index)
// push gradient, priority is negative index
- kvStore.foreach(_.push(index, gradList, -index))
+ kvStore.foreach(_.push(name, gradList, -index))
// pull back the weights
- kvStore.foreach(_.pull(index, argList, -index))
+ kvStore.foreach(_.pull(name, argList, -index))
}
}
}
@@ -189,14 +192,16 @@ object Model {
gradArrays: IndexedSeq[Array[NDArray]],
updater: MXKVStoreUpdater,
numDevice: Int,
+ paramNames: IndexedSeq[String],
kvStore: Option[KVStore] = None) {
(paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) =>
if (gradList != null) {
kvStore.foreach(kv => {
+ val name = paramNames(index)
// push gradient, priority is negative index
- kv.push(index, gradList, -index)
+ kv.push(name, gradList, -index)
// pull back the sum gradients, to the same locations.
- kv.pull(index, gradList, -index)
+ kv.pull(name, gradList, -index)
})
(argList zip gradList).zipWithIndex.foreach { case ((w: NDArray, g: NDArray), k: Int) =>
// faked an index here, to make optimizer create diff
@@ -295,11 +300,12 @@ object Model {
if (updateOnKVStore) {
updateParamsOnKVStore(executorManager.paramArrays,
executorManager.gradArrays,
- kvStore)
+ kvStore, executorManager.paramNames)
} else {
updateParams(executorManager.paramArrays,
executorManager.gradArrays,
updaterLocal, ctx.length,
+ executorManager.paramNames,
kvStore)
}
monitor.foreach(_.tocPrint())
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
----------------------------------------------------------------------
NOTE(review): Module.update now forwards execGroup.paramNames to
Model.updateParamsOnKVStore (as the trailing argument) and to Model.updateParams
(before kvstore, matching that method's new parameter order), so the KVStore is keyed
by parameter name on both update paths.
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
index f0b8da0..2b1d743 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
@@ -365,11 +365,11 @@ class Module(symbolVar: Symbol,
paramsDirty = true
if (updateOnKVStore) {
Model.updateParamsOnKVStore(execGroup.paramArrays,
- execGroup.gradArrays, kvstore)
+ execGroup.gradArrays, kvstore, execGroup.paramNames)
} else {
require(updater != None)
Model.updateParams(execGroup.paramArrays,
- execGroup.gradArrays, updater.orNull, contexts.length, kvstore)
+ execGroup.gradArrays, updater.orNull, contexts.length, execGroup.paramNames, kvstore)
}
}
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
----------------------------------------------------------------------
NOTE(review): existing tests move from the Int key 3 to the String key "3" with the same
expected values. The new "test aggregate" test pushes one NDArray per CPU context
(numDevs = 4) under a single key and asserts every pulled element equals the
per-key sum: numDevs for key "a" (all ones), numDevs * 2f for keys "b", "c", "d" (all twos).
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
index f024e8d..8df6d18 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
@@ -25,8 +25,8 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)
- kv.init(3, NDArray.ones(shape))
- kv.pull(3, ndArray)
+ kv.init("3", NDArray.ones(shape))
+ kv.pull("3", ndArray)
assert(ndArray.toArray === Array(1f, 1f))
}
@@ -35,12 +35,34 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)
- kv.init(3, NDArray.ones(shape))
- kv.push(3, NDArray.ones(shape) * 4)
- kv.pull(3, ndArray)
+ kv.init("3", NDArray.ones(shape))
+ kv.push("3", NDArray.ones(shape) * 4)
+ kv.pull("3", ndArray)
assert(ndArray.toArray === Array(4f, 4f))
}
+ test("test aggregate") {
+ val shape = Shape(4, 4)
+ val keys = Array("b", "c", "d")
+ val kv = KVStore.create()
+ kv.init("a", NDArray.zeros(shape))
+ kv.init(keys, Array.fill(keys.length)(NDArray.zeros(shape)))
+ val numDevs = 4
+ val devs = (0 until numDevs).map(Context.cpu(_))
+ val vals = devs.map(d => NDArray.ones(shape, d)).toArray
+ kv.push("a", vals)
+ kv.pull("a", outs = vals)
+ assert(vals.map(v => v.toArray.map(x => x - numDevs).sum).sum == 0f)
+
+ val valss = keys.map { k =>
+ val tmpVals = devs.map(d => NDArray.ones(shape, d) * 2f).toArray
+ kv.push(k, tmpVals)
+ kv.pull(k, outs = tmpVals)
+ tmpVals
+ }.flatten
+ assert(valss.map(v => v.toArray.map(x => x - numDevs * 2f).sum).sum == 0f)
+ }
+
test("updater runs when push") {
val kv = KVStore.create()
val updater = new MXKVStoreUpdater {
@@ -57,12 +79,12 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
val shape = Shape(2, 1)
val ndArray = NDArray.zeros(shape)
- kv.init(3, NDArray.ones(shape) * 4)
- kv.pull(3, ndArray)
+ kv.init("3", NDArray.ones(shape) * 4)
+ kv.pull("3", ndArray)
assert(ndArray.toArray === Array(4f, 4f))
- kv.push(3, NDArray.ones(shape))
- kv.pull(3, ndArray)
+ kv.push("3", NDArray.ones(shape))
+ kv.pull("3", ndArray)
assert(ndArray.toArray === Array(6f, 6f))
}
http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
----------------------------------------------------------------------
NOTE(review): the mxKVStorePush hunk previously deleted
env->ReleaseIntArrayElements(keys, keyArray, 0); it is kept here as a context line.
Every Get<Type>ArrayElements call must be paired with the matching
Release<Type>ArrayElements; without it the int[] key buffer leaks (or stays pinned)
on every Int-keyed push. The hunk headers are adjusted accordingly: the Push hunk is
now +688,37, and the Pull hunk is now +734,32. The latter's line count was also
inconsistent before (6 context + 26 added lines = 32). The diffstat at the top of the
message still reflects the original commit.
NOTE(review): the new *Ex wrappers do not check GetStringUTFChars for a NULL return
(allocation failure) before passing the key array to MXNet; consider bailing out early.
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index 65bf2b7..07fd075 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -654,6 +654,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit
return ret;
}
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInitEx
+ (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys, jlongArray values) {
+ const char **keyArray = new const char *[len];
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ keyArray[i] = key;
+ env->DeleteLocalRef(jkey);
+ }
+ jlong *valueArray = env->GetLongArrayElements(values, NULL);
+ int ret = MXKVStoreInitEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+ static_cast<mx_uint>(len),
+ keyArray,
+ reinterpret_cast<NDArrayHandle *>(valueArray));
+ env->ReleaseLongArrayElements(values, valueArray, 0);
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ env->ReleaseStringUTFChars(jkey, keyArray[i]);
+ env->DeleteLocalRef(jkey);
+ }
+ delete[] keyArray;
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
jlongArray values, jint priority) {
@@ -664,11 +688,37 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush
static_cast<const int *>(keyArray),
reinterpret_cast<NDArrayHandle *>(valueArray),
priority);
env->ReleaseIntArrayElements(keys, keyArray, 0);
env->ReleaseLongArrayElements(values, valueArray, 0);
return ret;
}
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePushEx
+ (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
+ jlongArray values, jint priority) {
+ const char **keyArray = new const char *[len];
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ keyArray[i] = key;
+ env->DeleteLocalRef(jkey);
+ }
+ jlong *valueArray = env->GetLongArrayElements(values, NULL);
+ int ret = MXKVStorePushEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+ static_cast<mx_uint>(len),
+ keyArray,
+ reinterpret_cast<NDArrayHandle *>(valueArray),
+ priority);
+ env->ReleaseLongArrayElements(values, valueArray, 0);
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ env->ReleaseStringUTFChars(jkey, keyArray[i]);
+ env->DeleteLocalRef(jkey);
+ }
+ delete[] keyArray;
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
jlongArray outs, jint priority) {
@@ -684,6 +734,32 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull
return ret;
}
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePullEx
+ (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
+ jlongArray outs, jint priority) {
+ const char **keyArray = new const char *[len];
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ const char *key = env->GetStringUTFChars(jkey, 0);
+ keyArray[i] = key;
+ env->DeleteLocalRef(jkey);
+ }
+ jlong *outArray = env->GetLongArrayElements(outs, NULL);
+ int ret = MXKVStorePullEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+ static_cast<mx_uint>(len),
+ keyArray,
+ reinterpret_cast<NDArrayHandle *>(outArray),
+ priority);
+ env->ReleaseLongArrayElements(outs, outArray, 0);
+ for (int i = 0; i < len; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+ env->ReleaseStringUTFChars(jkey, keyArray[i]);
+ env->DeleteLocalRef(jkey);
+ }
+ delete[] keyArray;
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType
(JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
const char *type;