You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/06/02 06:52:53 UTC
[incubator-mxnet] branch master updated: update scala api to properly track reshaped size (#11009)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7eb78d8 update scala api to properly track reshaped size (#11009)
7eb78d8 is described below
commit 7eb78d8833dc11f42f43e1a665da713ef6db7c81
Author: Jesse Brizzi <je...@users.noreply.github.com>
AuthorDate: Sat Jun 2 02:52:36 2018 -0400
update scala api to properly track reshaped size (#11009)
---
CONTRIBUTORS.md | 1 +
.../org/apache/mxnet/module/DataParallelExecutorGroup.scala | 11 +++++++----
.../core/src/main/scala/org/apache/mxnet/module/Module.scala | 2 ++
.../core/src/test/scala/org/apache/mxnet/ModuleSuite.scala | 12 ++++++++++++
4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 4bfafb6..f1ab129 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -170,4 +170,5 @@ List of Contributors
* [Sina Afrooze](https://github.com/safrooze)
* [Sergey Sokolov](https://github.com/Ishitori)
* [Thomas Delteil](https://github.com/ThomasDelteil)
+* [Jesse Brizzi](https://github.com/jessebrizzi)
* [Hang Zhang](http://hangzh.com)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index c13ebcd..1494dc8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -266,8 +266,8 @@ class DataParallelExecutorGroup private[module](
symbol: Symbol,
contexts: Array[Context],
workLoadList: IndexedSeq[Float],
- dataShapes: IndexedSeq[DataDesc],
- labelShapes: Option[IndexedSeq[DataDesc]] = None,
+ var dataShapes: IndexedSeq[DataDesc],
+ var labelShapes: Option[IndexedSeq[DataDesc]] = None,
private[module] val paramNames: IndexedSeq[String],
forTraining: Boolean,
inputsNeedGrad: Boolean,
@@ -356,7 +356,7 @@ class DataParallelExecutorGroup private[module](
* @param sharedGroup
* @param reshape
*/
- def bindExec(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]],
+ def bindExec(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]],
sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
this.batchSize = -1
dataLayouts = decideSlices(dataShapes)
@@ -379,6 +379,9 @@ class DataParallelExecutorGroup private[module](
).toArray
}
+ this.dataShapes = dataShapes
+ this.labelShapes = labelShapes
+
// convenient data structures
dataArrays = dataShapes.map(dataDesc =>
this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(dataDesc.name)) }
@@ -427,7 +430,7 @@ class DataParallelExecutorGroup private[module](
* @param dataShapes
* @param labelShapes
*/
- def reshape(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]]): Unit = {
+ def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
if (this._defaultExecs == null) {
this._defaultExecs = this.execs.map(x => x)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index d55a426..9cf64b1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -341,6 +341,8 @@ class Module(symbolVar: Symbol,
require(this.binded)
val (tdataShapes, tlabelShapes) = this._parseDataDesc(
this.dataNames, this.labelNames, dataShapes, labelShapes)
+ this.dataShapesVar = tdataShapes
+ this.labelShapesVar = tlabelShapes
this.execGroup.reshape(tdataShapes, tlabelShapes)
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 22b9c3b..8234568 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -157,6 +157,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
assert(mod.getOutputsMerged()(0).shape == dShape)
assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
+ // reshape module
dShape = Shape(14, 20)
mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
mod.forward(new DataBatch(
@@ -166,6 +167,17 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
mod.update()
assert(mod.getOutputsMerged()(0).shape == dShape)
assert(mod.getParams._1("fc_bias").toArray.forall(x => (x - -3f) < 1e-3))
+
+ // return to original binded shape
+ dShape = Shape(7, 20)
+ mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
+ mod.forward(new DataBatch(
+ data = IndexedSeq(NDArray.ones(dShape)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(dShape)))
+ mod.update()
+ assert(mod.getOutputsMerged()(0).shape == dShape)
+ assert(mod.getParams._1("fc_bias").toArray.forall(x => (x - -3f) < 1e-3))
}
test ("module setParams") {
--
To stop receiving notification emails like this one, please contact
nswamy@apache.org.