You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/06/02 06:52:53 UTC

[incubator-mxnet] branch master updated: update scala api to properly track reshaped size (#11009)

This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7eb78d8  update scala api to properly track reshaped size (#11009)
7eb78d8 is described below

commit 7eb78d8833dc11f42f43e1a665da713ef6db7c81
Author: Jesse Brizzi <je...@users.noreply.github.com>
AuthorDate: Sat Jun 2 02:52:36 2018 -0400

    update scala api to properly track reshaped size (#11009)
---
 CONTRIBUTORS.md                                              |  1 +
 .../org/apache/mxnet/module/DataParallelExecutorGroup.scala  | 11 +++++++----
 .../core/src/main/scala/org/apache/mxnet/module/Module.scala |  2 ++
 .../core/src/test/scala/org/apache/mxnet/ModuleSuite.scala   | 12 ++++++++++++
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 4bfafb6..f1ab129 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -170,4 +170,5 @@ List of Contributors
 * [Sina Afrooze](https://github.com/safrooze)
 * [Sergey Sokolov](https://github.com/Ishitori)
 * [Thomas Delteil](https://github.com/ThomasDelteil)
+* [Jesse Brizzi](https://github.com/jessebrizzi)
 * [Hang Zhang](http://hangzh.com)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index c13ebcd..1494dc8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -266,8 +266,8 @@ class DataParallelExecutorGroup private[module](
     symbol: Symbol,
     contexts: Array[Context],
     workLoadList: IndexedSeq[Float],
-    dataShapes: IndexedSeq[DataDesc],
-    labelShapes: Option[IndexedSeq[DataDesc]] = None,
+    var dataShapes: IndexedSeq[DataDesc],
+    var labelShapes: Option[IndexedSeq[DataDesc]] = None,
     private[module] val paramNames: IndexedSeq[String],
     forTraining: Boolean,
     inputsNeedGrad: Boolean,
@@ -356,7 +356,7 @@ class DataParallelExecutorGroup private[module](
    * @param sharedGroup
    * @param reshape
    */
-  def bindExec(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]],
+  def bindExec(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]],
                sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
     this.batchSize = -1
     dataLayouts = decideSlices(dataShapes)
@@ -379,6 +379,9 @@ class DataParallelExecutorGroup private[module](
       ).toArray
     }
 
+    this.dataShapes = dataShapes
+    this.labelShapes = labelShapes
+
     // convenient data structures
     dataArrays = dataShapes.map(dataDesc =>
       this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(dataDesc.name)) }
@@ -427,7 +430,7 @@ class DataParallelExecutorGroup private[module](
    * @param dataShapes
    * @param labelShapes
    */
-  def reshape(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]]): Unit = {
+  def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
     if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
       if (this._defaultExecs == null) {
         this._defaultExecs = this.execs.map(x => x)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index d55a426..9cf64b1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -341,6 +341,8 @@ class Module(symbolVar: Symbol,
     require(this.binded)
     val (tdataShapes, tlabelShapes) = this._parseDataDesc(
       this.dataNames, this.labelNames, dataShapes, labelShapes)
+    this.dataShapesVar = tdataShapes
+    this.labelShapesVar = tlabelShapes
     this.execGroup.reshape(tdataShapes, tlabelShapes)
   }
 
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 22b9c3b..8234568 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -157,6 +157,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     assert(mod.getOutputsMerged()(0).shape == dShape)
     assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
 
+    // reshape module
     dShape = Shape(14, 20)
     mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
     mod.forward(new DataBatch(
@@ -166,6 +167,17 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     mod.update()
     assert(mod.getOutputsMerged()(0).shape == dShape)
     assert(mod.getParams._1("fc_bias").toArray.forall(x => (x - -3f) < 1e-3))
+
+    // return to original binded shape
+    dShape = Shape(7, 20)
+    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
+    mod.forward(new DataBatch(
+      data = IndexedSeq(NDArray.ones(dShape)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(dShape)))
+    mod.update()
+    assert(mod.getOutputsMerged()(0).shape == dShape)
+    assert(mod.getParams._1("fc_bias").toArray.forall(x => (x - -3f) < 1e-3))
   }
 
   test ("module setParams") {

-- 
To stop receiving notification emails like this one, please contact
nswamy@apache.org.