You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in the plain-text conversion.
Posted to commits@mxnet.apache.org by lx...@apache.org on 2017/07/07 15:58:18 UTC

[05/50] [abbrv] incubator-mxnet-test git commit: [Scala] support str key type in kvstore (#6829)

[Scala] support str key type in kvstore (#6829)



Project: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/commit/5acc2c90
Tree: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/tree/5acc2c90
Diff: http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/diff/5acc2c90

Branch: refs/heads/master
Commit: 5acc2c90d026015ad02b66839c19f72bdce35dfb
Parents: 9e66154
Author: 梁德澎 <li...@gmail.com>
Authored: Tue Jun 27 22:50:29 2017 +0800
Committer: Yizhi Liu <ja...@gmail.com>
Committed: Tue Jun 27 22:50:29 2017 +0800

----------------------------------------------------------------------
 .../src/main/scala/ml/dmlc/mxnet/KVStore.scala  | 30 ++++----
 .../src/main/scala/ml/dmlc/mxnet/LibInfo.scala  | 14 ++++
 .../src/main/scala/ml/dmlc/mxnet/Model.scala    | 22 ++++--
 .../scala/ml/dmlc/mxnet/module/Module.scala     |  4 +-
 .../test/scala/ml/dmlc/mxnet/KVStoreSuite.scala | 40 +++++++---
 .../main/native/ml_dmlc_mxnet_native_c_api.cc   | 77 +++++++++++++++++++-
 6 files changed, 152 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
----------------------------------------------------------------------
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
index 32e0ace..94dd254 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
@@ -83,13 +83,13 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
    * @param keys The keys.
    * @param values The values.
    */
-  def init(keys: Array[Int], values: Array[NDArray]): Unit = {
+  def init(keys: Array[String], values: Array[NDArray]): Unit = { // via mxKVStoreInitEx
     require(keys.length == values.length, "len(keys) != len(values)")
     val valuePtrs = values.map(_.handle)
-    checkCall(_LIB.mxKVStoreInit(handle, keys.length, keys, valuePtrs))
+    checkCall(_LIB.mxKVStoreInitEx(handle, keys.length, keys, valuePtrs))
   }
 
-  def init(key: Int, value: NDArray): Unit = {
+  def init(key: String, value: NDArray): Unit = {
     init(Array(key), Array(value))
   }
 
@@ -107,24 +107,24 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
    *         The higher the priority, the faster this action is likely
    *         to be executed before other push actions.
    */
-  def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = {
+  def push(keys: Array[String], values: Array[NDArray], priority: Int): Unit = {
     require(keys.length == values.length, "len(keys) != len(values)")
     val valuePtrs = values.map(_.handle)
-    checkCall(_LIB.mxKVStorePush(handle, keys.length, keys, valuePtrs, priority))
+    checkCall(_LIB.mxKVStorePushEx(handle, keys.length, keys, valuePtrs, priority))
   }
 
-  def push(keys: Array[Int], values: Array[NDArray]): Unit = push(keys, values, 0)
+  def push(keys: Array[String], values: Array[NDArray]): Unit = push(keys, values, 0)
 
-  def push(key: Int, value: NDArray, priority: Int = 0): Unit = {
+  def push(key: String, value: NDArray, priority: Int = 0): Unit = {
     push(Array(key), Array(value), priority)
   }
 
-  def push(key: Int, values: Array[NDArray], priority: Int): Unit = {
+  def push(key: String, values: Array[NDArray], priority: Int): Unit = { // key repeated per value
     val keys = Array.fill(values.length)(key)
     push(keys, values, priority)
   }
 
-  def push(key: Int, values: Array[NDArray]): Unit = {
+  def push(key: String, values: Array[NDArray]): Unit = {
     push(key, values, 0)
   }
 
@@ -143,24 +143,24 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
    *     The higher the priority, the faster this action is likely
    *     to be executed before other push actions.
    */
-  def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = {
+  def pull(keys: Array[String], outs: Array[NDArray], priority: Int): Unit = {
     require(keys.length == outs.length, "len(keys) != len(outs)")
     val outPtrs = outs.map(_.handle)
-    checkCall(_LIB.mxKVStorePull(handle, keys.length, keys, outPtrs, priority))
+    checkCall(_LIB.mxKVStorePullEx(handle, keys.length, keys, outPtrs, priority))
   }
 
-  def pull(keys: Array[Int], outs: Array[NDArray]): Unit = pull(keys, outs, 0)
+  def pull(keys: Array[String], outs: Array[NDArray]): Unit = pull(keys, outs, 0)
 
-  def pull(key: Int, out: NDArray, priority: Int = 0): Unit = {
+  def pull(key: String, out: NDArray, priority: Int = 0): Unit = {
     pull(Array(key), Array(out), priority)
   }
 
-  def pull(key: Int, outs: Array[NDArray], priority: Int): Unit = {
+  def pull(key: String, outs: Array[NDArray], priority: Int): Unit = { // key repeated per out
     val keys = Array.fill(outs.length)(key)
     pull(keys, outs, priority)
   }
 
-  def pull(key: Int, outs: Array[NDArray]): Unit = {
+  def pull(key: String, outs: Array[NDArray]): Unit = {
     pull(key, outs, 0)
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
----------------------------------------------------------------------
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 97ba815..a943e31 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -117,16 +117,30 @@ private[mxnet] class LibInfo {
                             len: MXUint,
                             keys: Array[Int],
                             values: Array[NDArrayHandle]): Int
+  @native def mxKVStoreInitEx(handle: KVStoreHandle, // like mxKVStoreInit, with String keys
+                              len: MXUint,
+                              keys: Array[String],
+                              values: Array[NDArrayHandle]): Int
   @native def mxKVStorePush(handle: KVStoreHandle,
                             len: MXUint,
                             keys: Array[Int],
                             values: Array[NDArrayHandle],
                             priority: Int): Int
+  @native def mxKVStorePushEx(handle: KVStoreHandle, // like mxKVStorePush, with String keys
+                              len: MXUint,
+                              keys: Array[String],
+                              values: Array[NDArrayHandle],
+                              priority: Int): Int
   @native def mxKVStorePull(handle: KVStoreHandle,
                             len: MXUint,
                             keys: Array[Int],
                             outs: Array[NDArrayHandle],
                             priority: Int): Int
+  @native def mxKVStorePullEx(handle: KVStoreHandle, // like mxKVStorePull, with String keys
+                              len: MXUint,
+                              keys: Array[String],
+                              outs: Array[NDArrayHandle],
+                              priority: Int): Int
   @native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater): Int
   @native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
   @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
----------------------------------------------------------------------
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
index 69fe682..81ff1cf 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -163,9 +163,10 @@ object Model {
     require(paramArrays.length == paramNames.length)
     for (idx <- 0 until paramArrays.length) {
       val paramOnDevs = paramArrays(idx)
-      kvStore.init(idx, argParams(paramNames(idx)))
+      val name = paramNames(idx) // the param name is its kvstore key
+      kvStore.init(name, argParams(name))
       if (updateOnKVStore) {
-        kvStore.pull(idx, paramOnDevs, -idx)
+        kvStore.pull(name, paramOnDevs, -idx)
       }
     }
   }
@@ -173,13 +174,15 @@ object Model {
   // Perform update of param_arrays from grad_arrays on kvstore
   private[mxnet] def updateParamsOnKVStore(paramArrays: IndexedSeq[Array[NDArray]],
                                            gradArrays: IndexedSeq[Array[NDArray]],
-                                           kvStore: Option[KVStore]): Unit = {
+                                           kvStore: Option[KVStore],
+                                           paramNames: IndexedSeq[String]): Unit = {
     (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) =>
       if (gradList != null) {
+        val name = paramNames(index)
         // push gradient, priority is negative index
-        kvStore.foreach(_.push(index, gradList, -index))
+        kvStore.foreach(_.push(name, gradList, -index))
         // pull back the weights
-        kvStore.foreach(_.pull(index, argList, -index))
+        kvStore.foreach(_.pull(name, argList, -index))
       }
     }
   }
@@ -189,14 +192,16 @@ object Model {
                                   gradArrays: IndexedSeq[Array[NDArray]],
                                   updater: MXKVStoreUpdater,
                                   numDevice: Int,
+                                  paramNames: IndexedSeq[String],
                                   kvStore: Option[KVStore] = None) {
     (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) =>
       if (gradList != null) {
         kvStore.foreach(kv => {
+          val name = paramNames(index)
           // push gradient, priority is negative index
-          kv.push(index, gradList, -index)
+          kv.push(name, gradList, -index)
           // pull back the sum gradients, to the same locations.
-          kv.pull(index, gradList, -index)
+          kv.pull(name, gradList, -index)
         })
         (argList zip gradList).zipWithIndex.foreach { case ((w: NDArray, g: NDArray), k: Int) =>
           // faked an index here, to make optimizer create diff
@@ -295,11 +300,12 @@ object Model {
           if (updateOnKVStore) {
             updateParamsOnKVStore(executorManager.paramArrays,
               executorManager.gradArrays,
-              kvStore)
+              kvStore, executorManager.paramNames)
           } else {
             updateParams(executorManager.paramArrays,
               executorManager.gradArrays,
               updaterLocal, ctx.length,
+              executorManager.paramNames,
               kvStore)
           }
           monitor.foreach(_.tocPrint())

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
----------------------------------------------------------------------
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
index f0b8da0..2b1d743 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
@@ -365,11 +365,11 @@ class Module(symbolVar: Symbol,
     paramsDirty = true
     if (updateOnKVStore) {
       Model.updateParamsOnKVStore(execGroup.paramArrays,
-        execGroup.gradArrays, kvstore)
+        execGroup.gradArrays, kvstore, execGroup.paramNames) // param names serve as kvstore keys
     } else {
       require(updater != None)
       Model.updateParams(execGroup.paramArrays,
-        execGroup.gradArrays, updater.orNull, contexts.length, kvstore)
+        execGroup.gradArrays, updater.orNull, contexts.length, execGroup.paramNames, kvstore)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
----------------------------------------------------------------------
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
index f024e8d..8df6d18 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
@@ -25,8 +25,8 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
     val shape = Shape(2, 1)
     val ndArray = NDArray.zeros(shape)
 
-    kv.init(3, NDArray.ones(shape))
-    kv.pull(3, ndArray)
+    kv.init("3", NDArray.ones(shape))
+    kv.pull("3", ndArray)
     assert(ndArray.toArray === Array(1f, 1f))
   }
 
@@ -35,12 +35,34 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
     val shape = Shape(2, 1)
     val ndArray = NDArray.zeros(shape)
 
-    kv.init(3, NDArray.ones(shape))
-    kv.push(3, NDArray.ones(shape) * 4)
-    kv.pull(3, ndArray)
+    kv.init("3", NDArray.ones(shape))
+    kv.push("3", NDArray.ones(shape) * 4)
+    kv.pull("3", ndArray)
     assert(ndArray.toArray === Array(4f, 4f))
   }
 
+  test("test aggregate") {
+    val shape = Shape(4, 4)
+    val keys = Array("b", "c", "d")
+    val kv = KVStore.create()
+    kv.init("a", NDArray.zeros(shape))
+    kv.init(keys, Array.fill(keys.length)(NDArray.zeros(shape)))
+    val numDevs = 4
+    val devs = (0 until numDevs).map(Context.cpu(_))
+    val vals = devs.map(d => NDArray.ones(shape, d)).toArray
+    kv.push("a", vals)
+    kv.pull("a", outs = vals)
+    assert(vals.forall(_.toArray.forall(_ == numDevs))) // each entry must equal the device sum
+
+    val valss = keys.map { k =>
+      val tmpVals = devs.map(d => NDArray.ones(shape, d) * 2f).toArray
+      kv.push(k, tmpVals)
+      kv.pull(k, outs = tmpVals)
+      tmpVals
+    }.flatten
+    assert(valss.forall(_.toArray.forall(_ == numDevs * 2f)))
+  }
+
   test("updater runs when push") {
     val kv = KVStore.create()
     val updater = new MXKVStoreUpdater {
@@ -57,12 +79,12 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
     val shape = Shape(2, 1)
     val ndArray = NDArray.zeros(shape)
 
-    kv.init(3, NDArray.ones(shape) * 4)
-    kv.pull(3, ndArray)
+    kv.init("3", NDArray.ones(shape) * 4)
+    kv.pull("3", ndArray)
     assert(ndArray.toArray === Array(4f, 4f))
 
-    kv.push(3, NDArray.ones(shape))
-    kv.pull(3, ndArray)
+    kv.push("3", NDArray.ones(shape))
+    kv.pull("3", ndArray)
     assert(ndArray.toArray === Array(6f, 6f))
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-mxnet-test/blob/5acc2c90/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
----------------------------------------------------------------------
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index 65bf2b7..07fd075 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -654,6 +654,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInitEx
+  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys, jlongArray values) {
+  const char **keyArray = new const char *[len];
+  for (int i = 0; i < len; i++) {
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    const char *key = env->GetStringUTFChars(jkey, 0);
+    keyArray[i] = key;
+    env->DeleteLocalRef(jkey);  // UTF chars stay valid until released below
+  }
+  jlong *valueArray = env->GetLongArrayElements(values, NULL);
+  int ret = MXKVStoreInitEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+                            static_cast<mx_uint>(len),
+                            keyArray,
+                            reinterpret_cast<NDArrayHandle *>(valueArray));
+  env->ReleaseLongArrayElements(values, valueArray, 0);
+  for (int i = 0; i < len; i++) {  // re-fetch each key to release its UTF chars
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    env->ReleaseStringUTFChars(jkey, keyArray[i]);
+    env->DeleteLocalRef(jkey);
+  }
+  delete[] keyArray;
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush
   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
     jlongArray values, jint priority) {
@@ -664,11 +688,37 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush
                           static_cast<const int *>(keyArray),
                           reinterpret_cast<NDArrayHandle *>(valueArray),
                           priority);
   env->ReleaseIntArrayElements(keys, keyArray, 0);
   env->ReleaseLongArrayElements(values, valueArray, 0);
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePushEx
+  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
+    jlongArray values, jint priority) {
+  const char **keyArray = new const char *[len];
+  for (int i = 0; i < len; i++) {
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    const char *key = env->GetStringUTFChars(jkey, 0);
+    keyArray[i] = key;
+    env->DeleteLocalRef(jkey);  // UTF chars stay valid until released below
+  }
+  jlong *valueArray = env->GetLongArrayElements(values, NULL);
+  int ret = MXKVStorePushEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+                            static_cast<mx_uint>(len),
+                            keyArray,
+                            reinterpret_cast<NDArrayHandle *>(valueArray),
+                            priority);
+  env->ReleaseLongArrayElements(values, valueArray, 0);
+  for (int i = 0; i < len; i++) {  // re-fetch each key to release its UTF chars
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    env->ReleaseStringUTFChars(jkey, keyArray[i]);
+    env->DeleteLocalRef(jkey);
+  }
+  delete[] keyArray;
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull
   (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
     jlongArray outs, jint priority) {
@@ -684,6 +734,32 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePullEx
+  (JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jobjectArray keys,
+    jlongArray outs, jint priority) {
+  const char **keyArray = new const char *[len];
+  for (int i = 0; i < len; i++) {
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    const char *key = env->GetStringUTFChars(jkey, 0);
+    keyArray[i] = key;
+    env->DeleteLocalRef(jkey);  // UTF chars stay valid until released below
+  }
+  jlong *outArray = env->GetLongArrayElements(outs, NULL);
+  int ret = MXKVStorePullEx(reinterpret_cast<KVStoreHandle>(kvStorePtr),
+                            static_cast<mx_uint>(len),
+                            keyArray,
+                            reinterpret_cast<NDArrayHandle *>(outArray),
+                            priority);
+  env->ReleaseLongArrayElements(outs, outArray, 0);
+  for (int i = 0; i < len; i++) {  // re-fetch each key to release its UTF chars
+    jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(keys, i));
+    env->ReleaseStringUTFChars(jkey, keyArray[i]);
+    env->DeleteLocalRef(jkey);
+  }
+  delete[] keyArray;
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType
   (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
   const char *type;