You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by li...@apache.org on 2017/08/04 15:47:00 UTC

[incubator-mxnet] branch master updated: [Scala] Make Module Api sync with Python interface (#7246)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 14b83fc  [Scala] Make Module Api sync with Python interface (#7246)
14b83fc is described below

commit 14b83fccef7b96f8d38d780dbce3d0ef47267934
Author: 梁德澎 <li...@gmail.com>
AuthorDate: Fri Aug 4 23:46:51 2017 +0800

    [Scala] Make Module Api sync with Python interface (#7246)
    
    * [Scala] Make Module Api sync with Python interface
    
    * fix
---
 .../scala/ml/dmlc/mxnet/module/BaseModule.scala    |  38 ++-
 .../mxnet/module/DataParallelExecutorGroup.scala   |  53 ++-
 .../main/scala/ml/dmlc/mxnet/module/Module.scala   |  69 +++-
 .../ml/dmlc/mxnet/module/SequentialModule.scala    |  10 +-
 .../src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala | 368 +++++++++++++++++++++
 .../test/scala/ml/dmlc/mxnet/OperatorSuite.scala   |   2 +-
 6 files changed, 514 insertions(+), 26 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
index c1cb91d..0a73e1a 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
@@ -121,6 +121,7 @@ abstract class BaseModule {
   private[module] var auxParams: Map[String, NDArray] = null
 
   // High Level API
+  def getSymbol: Symbol = symbol  // Returns the Symbol held by this module.
 
   // A convenient function that calls both `forward` and `backward`.
   def forwardBackward(dataBatch: DataBatch): Unit = {
@@ -259,7 +260,7 @@ abstract class BaseModule {
   /**
    * Get parameters, those are potentially copies of the the actual parameters used
    * to do computation on the device.
-   * @return `(arg_params, aux_params)`, a pair of dictionary of name to value mapping.
+   * @return `(argParams, auxParams)`, a pair of maps from parameter name to value.
    */
   def getParams: (Map[String, NDArray], Map[String, NDArray])
 
@@ -267,41 +268,52 @@ abstract class BaseModule {
    * Initialize the parameters and auxiliary states.
    * @param initializer : Initializer
    *         Called to initialize parameters if needed.
-   *     arg_params : dict
+   *     argParams : Map[String, NDArray]
    *         If not None, should be a dictionary of existing arg_params. Initialization
    *         will be copied from that.
-   *     aux_params : dict
+   *     auxParams : Map[String, NDArray]
    *         If not None, should be a dictionary of existing aux_params. Initialization
    *         will be copied from that.
-   *     allow_missing : bool
+   *     allowMissing : Boolean
    *         If true, params could contain missing values, and the initializer will be
    *         called to fill those missing params.
-   *     force_init : bool
+   *     forceInit : Boolean
    *         If true, will force re-initialize even if already initialized.
+   *     allowExtra : Boolean
+   *         Whether to allow extra parameters that are not needed by the symbol.
+   *         If true, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that are not needed by the executor.
    */
   def initParams(initializer: Initializer = new Uniform(0.01f),
                  argParams: Map[String, NDArray] = null,
                  auxParams: Map[String, NDArray] = null,
-                 allowMissing: Boolean = false, forceInit: Boolean = false): Unit
+                 allowMissing: Boolean = false,
+                 forceInit: Boolean = false,
+                 allowExtra: Boolean = false): Unit
 
   /**
    * Assign parameter and aux state values.
-   *     arg_params : dict
+   *     argParams : Map[String, NDArray]
    *         Dictionary of name to value (`NDArray`) mapping.
-   *     aux_params : dict
+   *     auxParams : Map[String, NDArray]
    *         Dictionary of name to value (`NDArray`) mapping.
-   *     allow_missing : bool
+   *     allowMissing : Boolean
    *         If true, params could contain missing values, and the initializer will be
    *         called to fill those missing params.
-   *     force_init : bool
+   *     forceInit : Boolean
    *         If true, will force re-initialize even if already initialized.
+   *     allowExtra : Boolean
+   *         Whether to allow extra parameters that are not needed by the symbol.
+   *         If true, no error will be thrown when argParams or auxParams contain
+   *         extra parameters. All flags are forwarded unchanged to initParams.
    */
   def setParams(argParams: Map[String, NDArray],
                 auxParams: Map[String, NDArray],
                 allowMissing: Boolean = false,
-                forceInit: Boolean = true): Unit = {
-    initParams(initializer = null, argParams = argParams, auxParams = auxParams,
-      allowMissing = allowMissing, forceInit = forceInit)
+                forceInit: Boolean = true,
+                allowExtra: Boolean = false): Unit = {
+    initParams(initializer = null, argParams = argParams, auxParams = auxParams,
+      allowMissing = allowMissing, forceInit = forceInit, allowExtra = allowExtra)
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
index 2e724c6..ea78962 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
@@ -297,6 +297,7 @@ class DataParallelExecutorGroup private[module](
 
   private var batchSize: Int = -1
   private var slices: Array[(Int, Int)] = null
+  private var _defaultExecs: Array[Executor] = null  // snapshot of execs from first reshape()
   private var execs: Array[Executor] = null
   private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
   private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -305,8 +306,8 @@ class DataParallelExecutorGroup private[module](
   private[module] var auxArrays: IndexedSeq[Array[NDArray]] = null
   private var inputGradArrays: IndexedSeq[Array[NDArray]] = null
 
-  private val dataLayouts = decideSlices(dataShapes)
-  private val labelLayouts =
+  private var dataLayouts = decideSlices(dataShapes)  // var: bindExec recomputes both layouts
+  private var labelLayouts =
     // call it to make sure labels has the same batch size as data
     if (labelShapes != None) decideSlices(labelShapes.get)
     else null
@@ -349,12 +350,30 @@ class DataParallelExecutorGroup private[module](
    * @param dataShapes DataDesc for input data.
    * @param labelShapes DataDesc for input labels.
    * @param sharedGroup
+   * @param reshape If true, derive executors from `_defaultExecs` instead of binding anew.
    */
   def bindExec(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]],
-               sharedGroup: Option[DataParallelExecutorGroup]): Unit = {
-    execs = (0 until contexts.length).map(i =>
-      bindIthExec(i, dataShapes, labelShapes, sharedGroup)
-    ).toArray
+               sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
+    this.batchSize = -1
+    dataLayouts = decideSlices(dataShapes)
+    labelLayouts = {
+      // call it to make sure labels has the same batch size as data
+      if (labelShapes != None) decideSlices(labelShapes.get)
+      else null
+    }
+    if (reshape) {
+      (0 until contexts.length).foreach { i =>
+        val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
+        val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
+        val inputShapes
+          = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
+        execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+      }
+    } else {
+      execs = (0 until contexts.length).map(i =>
+        bindIthExec(i, dataShapes, labelShapes, sharedGroup)
+      ).toArray
+    }
 
     // convenient data structures
     dataArrays = dataShapes.map(dataDesc =>
@@ -400,12 +419,30 @@ class DataParallelExecutorGroup private[module](
   }
 
   /**
+   * Reshape executors for new input shapes; no-op if both equal this.dataShapes/labelShapes.
+   * @param dataShapes DataDesc for the new input data.
+   * @param labelShapes DataDesc for the new input labels, if any.
+   */
+  def reshape(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]]): Unit = {
+    if (dataShapes != this.dataShapes || labelShapes != this.labelShapes) {
+      if (_defaultExecs == null) {
+        _defaultExecs = execs.clone()  // first-bound executors: the base of every reshape
+      }
+      bindExec(dataShapes, labelShapes, sharedGroup = None, reshape = true)
+    }
+  }
+
+  /**
    * Assign, i.e. copy parameters to all the executors.
    * @param argParams A dictionary of name to `NDArray` parameter mapping.
    * @param auxParams A dictionary of name to `NDArray` auxiliary variable mapping.
+   * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+   *         If true, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that are not needed by the executor.
    */
-  def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
-    execs.foreach(_.copyParamsFrom(argParams, auxParams))
+  def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray],
+                allowExtra: Boolean = false): Unit = {
+    for (exec <- execs) exec.copyParamsFrom(argParams, auxParams, allowExtraParams = allowExtra)
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
index 2b1d743..b9cc078 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
@@ -107,11 +107,16 @@ class Module(symbolVar: Symbol,
    * @param allowMissing If true, params could contain missing values,
    *                     and the initializer will be called to fill those missing params.
    * @param forceInit If true, will force re-initialize even if already initialized.
+   * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+   *         If true, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that are not needed by the executor.
    */
   override def initParams(initializer: Initializer = new Uniform(0.01f),
                           argParams: Map[String, NDArray] = null,
                           auxParams: Map[String, NDArray] = null,
-                          allowMissing: Boolean = false, forceInit: Boolean = false): Unit = {
+                          allowMissing: Boolean = false,
+                          forceInit: Boolean = false,
+                          allowExtra: Boolean = false): Unit = {
     if (paramsInitialized && !forceInit) {
       return
     }
@@ -141,7 +146,7 @@ class Module(symbolVar: Symbol,
     this.paramsDirty = false
 
     // copy the initialized parameters to devices
-    this.execGroup.setParams(this.argParams, this.auxParams)
+    this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
   }
 
   // Internal helper for parameter initialization
@@ -262,6 +267,46 @@ class Module(symbolVar: Symbol,
   }
 
   /**
+   * Check that the names carried by `dataShapes` equal `dataNames` (order-insensitive).
+   */
+  @throws(classOf[IllegalArgumentException])
+  private def _checkNamesMatch(dataNames: IndexedSeq[String], dataShapes: IndexedSeq[DataDesc],
+                               name: String, throwEx: Boolean): Unit = {
+    val provided = dataShapes.map(_.name)
+    if (provided.sorted != dataNames.sorted) {
+      val msg = s"Data provided by ${name}_shapes don't match names specified by " +
+        s"${name}_names (${dataShapes.mkString(", ")} vs. ${dataNames.mkString(", ")})"
+      // a mismatch is fatal only when requested; otherwise it is just logged
+      if (!throwEx) logger.warn(msg) else throw new IllegalArgumentException(msg)
+    }
+  }
+
+  /**
+   * Name-check the descriptors (data: throw, label: warn) and return them unchanged.
+   */
+  @throws(classOf[IllegalArgumentException])
+  private def _parseDataDesc(dataNames: IndexedSeq[String], labelNames: IndexedSeq[String],
+      dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]):
+      (IndexedSeq[DataDesc], Option[IndexedSeq[DataDesc]]) = {
+    _checkNamesMatch(dataNames, dataShapes, "data", throwEx = true)
+    labelShapes.foreach(ls => _checkNamesMatch(labelNames, ls, "label", throwEx = false))
+    (dataShapes, labelShapes)
+  }
+
+  /**
+   * Reshape the bound executors for new input shapes (the module must already be bound).
+   * @param dataShapes New data descriptors, typically `dataIter.provideData`.
+   * @param labelShapes New label descriptors, typically `dataIter.provideLabel`.
+   */
+  def reshape(dataShapes: IndexedSeq[DataDesc],
+              labelShapes: Option[IndexedSeq[DataDesc]] = None): Unit = {
+    require(binded)
+    val (checkedData, checkedLabel) =
+      _parseDataDesc(dataNames, labelNames, dataShapes, labelShapes)
+    execGroup.reshape(checkedData, checkedLabel)
+  }
+
+  /**
    * Install and initialize optimizers.
    * @param kvstore
    * @param optimizer
@@ -344,6 +389,26 @@ class Module(symbolVar: Symbol,
    */
   def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
     require(binded && paramsInitialized)
+    val currDataShapes = this.dataShapes.map(_.shape)  // shapes currently recorded on the module
+    val newDataShapes = dataBatch.data.map(_.shape)  // shapes of this batch's arrays
+    if (currDataShapes != newDataShapes) {  // shapes changed: reshape before running
+      val newDShapes: IndexedSeq[DataDesc] =
+        if (dataBatch.provideData != null) dataBatch.provideData  // batch's own descriptors
+        else {  // otherwise reuse each input's name/dtype/layout with the new shape
+          this.dataShapes.zip(newDataShapes).map { case (i, shape) =>
+            DataDesc(i.name, shape, i.dtype, i.layout)
+          }
+        }
+      val newLShapes: Option[IndexedSeq[DataDesc]] =
+        if (dataBatch.provideLabel != null) Some(dataBatch.provideLabel)
+        else if (dataBatch.label != null && dataBatch.label.length > 0
+            && this.labelShapes != null) {  // derive label descriptors from label arrays
+          Some(this.labelShapes.zip(dataBatch.label).map { case (i, j) =>
+            DataDesc(i.name, j.shape, i.dtype, i.layout)
+          })
+        } else None  // no label descriptors available
+      this.reshape(newDShapes, newLShapes)  // TODO(review): is this.dataShapes refreshed?
+    }
     execGroup.forward(dataBatch, isTrain)
   }
 
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
index dfa63eb..a77041d 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
@@ -144,11 +144,16 @@ class SequentialModule extends BaseModule {
    * @param allowMissing If true, params could contain missing values,
    *                     and the initializer will be called to fill those missing params.
    * @param forceInit If true, will force re-initialize even if already initialized.
+   * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+   *         If true, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that are not needed by the executor.
    */
   override def initParams(initializer: Initializer = new Uniform(0.01f),
                           argParams: Map[String, NDArray] = null,
                           auxParams: Map[String, NDArray] = null,
-                          allowMissing: Boolean = false, forceInit: Boolean = false): Unit = {
+                          allowMissing: Boolean = false,
+                          forceInit: Boolean = false,
+                          allowExtra: Boolean = false): Unit = {
     if (this.paramsInitialized && !forceInit) {
       return
     }
@@ -156,7 +161,8 @@ class SequentialModule extends BaseModule {
 
     for (module <- this.modules) {
       module.initParams(initializer = initializer, argParams = argParams,
-          auxParams = auxParams, allowMissing = allowMissing, forceInit = forceInit)
+          auxParams = auxParams, allowMissing = allowMissing,
+          forceInit = forceInit, allowExtra = allowExtra)
     }
 
     // Internal function to help checking duplicated names,
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
new file mode 100644
index 0000000..ab48ef7
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import ml.dmlc.mxnet.CheckUtils._
+import ml.dmlc.mxnet.module._
+import ml.dmlc.mxnet.optimizer._
+import ml.dmlc.mxnet.io._
+
+class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+  test ("model dtype") {
+    val dType = DType.Float16
+    val dShape = Shape(3, 8, 7)
+
+    var sym = Symbol.Variable("data")
+    sym = Symbol.Activation(attr = Map("__layout__" -> "TNC"))()(
+      Map("data" -> sym, "act_type" -> "relu"))
+
+    val mod = new Module(sym, IndexedSeq("data"), null,
+      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, "TNC")))
+    mod.initParams()
+    mod.forward(new DataBatch(
+      data = IndexedSeq(NDArray.ones(dShape, dtype = dType)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(dShape, dtype = dType)))
+
+    assert(mod.getOutputs.flatten.forall(_.dtype == dType))
+  }
+
+  test ("module input_grads") {
+    val a = Symbol.Variable("a", kwargs = Map("__layout__" -> "NC"))
+    val b = Symbol.Variable("b", kwargs = Map("__layout__" -> "NC"))
+    var c = Symbol.Variable("c", kwargs = Map("__layout__" -> "NC"))
+
+    import SymbolConversions._
+    c = a + 2 * b + 3 * c
+
+    val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
+      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    mod.bind(dataShapes = IndexedSeq(
+      DataDesc("b", Shape(5, 5)),
+      DataDesc("c", Shape(5, 5)),
+      DataDesc("a", Shape(5, 5))),
+      inputsNeedGrad = true
+    )
+    mod.initParams()
+    mod.forward(new DataBatch(
+      data = IndexedSeq(
+        NDArray.ones(5, 5), NDArray.ones(5, 5), NDArray.ones(5, 5)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(5, 5)))
+
+    val inputGrads = mod.getInputGradsMerged()
+    // NOTE(review): grads read as (a, b, c) though dataNames order is (b, c, a); confirm
+    val aGrad = inputGrads(0).toArray
+    val bGrad = inputGrads(1).toArray
+    val cGrad = inputGrads(2).toArray
+
+    assert(aGrad.forall(_ == 1f))
+    assert(bGrad.forall(_ == 2f))
+    assert(cGrad.forall(_ == 3f))
+  }
+
+  test ("module layout") {
+    var sym = Symbol.Variable("data")
+    sym = Symbol.Activation(attr = Map("__layout__" -> "TNC"))()(
+      Map("data" -> sym, "act_type" -> "relu"))
+
+    val dShape = Shape(3, 8, 7)
+    val mod = new Module(sym, IndexedSeq("data"), null,
+      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "TNC")))
+    mod.initParams()
+    mod.forward(new DataBatch(
+      data = IndexedSeq(NDArray.ones(dShape)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(dShape)))
+    assert(mod.getOutputsMerged()(0).shape == dShape)
+
+    val hdShape = Shape(3, 4, 7)
+    for (x <- mod.getOutputs) assert(x(0).shape == hdShape)
+  }
+
+  test ("save load") {
+    def mapEqu(a: Map[String, NDArray], b: Map[String, NDArray]): Unit = {
+      assert(a.keySet == b.keySet)  // same parameter names; values are compared below
+      for (k <- a.keys) assert(a(k) == b(k))
+    }
+
+    var sym = Symbol.Variable("data")
+    sym = Symbol.FullyConnected()()(Map("data" -> sym, "num_hidden" -> 100))
+
+    // single device
+    var mod = new Module(sym, IndexedSeq("data"), null)
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod.initParams()
+    mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+    mod.update()
+    mod.saveCheckpoint("test", 0, saveOptStates = true)
+
+    var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+    assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
+    mapEqu(mod.getParams._1, mod2.getParams._1)
+
+    // multi device
+    mod = new Module(sym, IndexedSeq("data"), null,
+      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod.initParams()
+    mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+    mod.update()
+    mod.saveCheckpoint("test", 0, saveOptStates = true)
+
+    mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+    mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+    assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
+    mapEqu(mod.getParams._1, mod2.getParams._1)
+  }
+
+  test ("module reshape") {
+    var sym = Symbol.Variable("data")
+    sym = Symbol.FullyConnected("fc")()(Map("data" -> sym, "num_hidden" -> 20))
+
+    var dShape = Shape(7, 20)
+    val mod = new Module(sym, IndexedSeq("data"), null,
+      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape)))
+    mod.initParams()
+    mod.initOptimizer(optimizer = new SGD(learningRate = 1f))
+
+    mod.forward(new DataBatch(
+      data = IndexedSeq(NDArray.ones(dShape)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(dShape)))
+    mod.update()
+    assert(mod.getOutputsMerged()(0).shape == dShape)
+    assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
+
+    dShape = Shape(14, 20)
+    mod.reshape(IndexedSeq(DataDesc("data", dShape)))
+    mod.forward(new DataBatch(
+      data = IndexedSeq(NDArray.ones(dShape)),
+      label = null, index = null, pad = 0))
+    mod.backward(Array(NDArray.ones(dShape)))
+    mod.update()
+    assert(mod.getOutputsMerged()(0).shape == dShape)
+    assert(mod.getParams._1("fc_bias").toArray.forall(x => math.abs(x - -3f) < 1e-3))
+  }
+
+  test ("module setParams") {
+    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
+    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+    val trainData = new NDArrayIter(
+      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+
+    // symbols
+    var x = Symbol.Variable("data")
+    x = Symbol.FullyConnected(name = "fc_0")()(Map("data" -> x, "num_hidden" -> 2))
+    x = Symbol.Activation(name = "act_0")()(Map("data" -> x, "act_type" -> "sigmoid"))
+    x = Symbol.FullyConnected(name = "fc_1")()(Map("data" -> x, "num_hidden" -> 2))
+    x = Symbol.Activation(name = "act_1")()(Map("data" -> x, "act_type" -> "sigmoid"))
+    x = Symbol.LinearRegressionOutput(name = "softmax")()(Map("data" -> x, "grad_scale" -> 2))
+
+    // create module
+    val mod = new Module(x, contexts = Array(Context.cpu()))
+    mod.bind(dataShapes = trainData.provideData,
+      Option(trainData.provideLabel))
+    val argParamsCorrect = Map(
+      "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+      "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+      "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+      "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+    )
+    val argParamsMissing = Map(
+      "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+      "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+      "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2))
+    )
+    val argParamsExtra = Map(
+      "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+      "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+      "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+      "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2)),
+      "fc_2_weight" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+    )
+
+    mod.setParams(forceInit = true, argParams = argParamsCorrect,
+      auxParams = null)
+
+    // test allow missing
+    mod.setParams(forceInit = true, argParams = argParamsMissing,
+      auxParams = null, allowMissing = true)
+
+    // test allow extra
+    mod.setParams(forceInit = true, argParams = argParamsExtra, auxParams = null,
+      allowMissing = true, allowExtra = true)
+  }
+
+  test ("monitor") {
+    // data iter
+    val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
+    val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+    val trainData = new NDArrayIter(
+      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+
+    // symbols
+    var x = Symbol.Variable("data")
+    x = Symbol.FullyConnected(name = "fc_0")()(Map("data" -> x, "num_hidden" -> 2))
+    x = Symbol.Activation(name = "act_0")()(Map("data" -> x, "act_type" -> "sigmoid"))
+    x = Symbol.FullyConnected(name = "fc_1")()(Map("data" -> x, "num_hidden" -> 2))
+    x = Symbol.Activation(name = "act_1")()(Map("data" -> x, "act_type" -> "sigmoid"))
+    x = Symbol.LinearRegressionOutput(name = "softmax")()(Map("data" -> x, "grad_scale" -> 2))
+
+    // create monitor
+    def meanAbs(x: NDArray): NDArray = {
+      val sumAbs = NDArray.sum(NDArray.abs(x))
+      sumAbs / x.shape.product
+    }
+    val mon = new Monitor(1, statFunc = meanAbs)
+
+    // create module
+    val mod = new Module(x, contexts = Array(Context.cpu()))
+    mod.bind(dataShapes = trainData.provideData,
+      Option(trainData.provideLabel))
+    mod.installMonitor(mon)
+    val argParams = Map(
+      "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+      "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+      "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+      "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+    )
+    mod.initParams(argParams = argParams)
+
+    val dataBatch = trainData.next()
+    mon.tic()
+    mod.forwardBackward(dataBatch)
+    val res = mon.toc()
+    val keys = Array("act_0", "act_1", "data", "fc_0", "fc_1", "softmax")
+    val monResultCounts = Array(0, 0, 0, 0, 0, 0)
+    assert(res.length == 21)
+    // Tally each entry under the first element of `keys` that its field `k` starts with.
+    for ((_, k, _) <- res) {
+      val idx = keys.indexWhere(key => k.startsWith(key))
+      if (idx >= 0) {
+        monResultCounts(idx) += 1
+      }
+    }
+    // Expected per-key counts; they sum to the 21 entries asserted above.
+    assert(monResultCounts.sameElements(Array(2, 2, 1, 6, 6, 4)))
+  }
+
+  test ("forward reshape") {
+    val numClass = 10
+    val data1 = Symbol.Variable("data1")
+    val data2 = Symbol.Variable("data2")
+    val conv1 = Symbol.Convolution()()(Map("data" -> data1,
+        "kernel" -> "(2, 2)", "num_filter" -> 2, "stride" -> "(2, 2)"))
+    val conv2 = Symbol.Convolution()()(Map("data" -> data2,
+        "kernel" -> "(3, 3)", "num_filter" -> 3, "stride" -> "(1, 1)"))
+    val pooling1 = Symbol.Pooling()()(Map("data" -> conv1,
+        "kernel" -> "(2, 2)", "pool_type" -> "avg", "stride" -> "(1, 1)"))
+    val pooling2 = Symbol.Pooling()()(Map("data" -> conv2,
+        "kernel" -> "(2, 2)", "pool_type" -> "max", "stride" -> "(1, 1)"))
+    val flatten1 = Symbol.flatten()()(Map("data" -> pooling1))
+    val flatten2 = Symbol.flatten()()(Map("data" -> pooling2))
+    val sum = Symbol.sum()()(Map("data" -> flatten1, "axis" -> 1)) +
+      Symbol.sum()()(Map("data" -> flatten2, "axis" -> 1))
+    val fc = Symbol.FullyConnected()()(
+      Map("data" -> sum, "num_hidden" -> numClass))
+    val sym = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc))
+
+    var dShape1 = Shape(10, 3, 64, 64)
+    var dShape2 = Shape(10, 3, 32, 32)
+    var lShape = Shape(10)
+
+    val mod = new Module(sym, IndexedSeq("data1", "data2"))
+    mod.bind(dataShapes = IndexedSeq(
+      DataDesc("data1", dShape1), DataDesc("data2", dShape2)),
+      labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape)))
+    )
+    mod.initParams()
+    mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
+
+    // Train with original data shapes
+    var dataBatch = new DataBatch(
+      data = IndexedSeq(
+        NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+    mod.forward(dataBatch)
+    assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+    mod.backward()
+    mod.update()
+
+    dShape1 = Shape(3, 3, 64, 64)
+    dShape2 = Shape(3, 3, 32, 32)
+    lShape = Shape(3)
+    dataBatch = new DataBatch(
+      data = IndexedSeq(
+        NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+    mod.forward(dataBatch)
+    assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+    mod.backward()
+    mod.update()
+
+    dShape1 = Shape(20, 3, 64, 64)
+    dShape2 = Shape(20, 3, 32, 32)
+    lShape = Shape(20)
+    dataBatch = new DataBatch(
+      data = IndexedSeq(
+        NDArray.random_uniform(Map("low" -> 3, "high" -> 5, "shape" -> dShape1.toString()))(),
+        NDArray.random_uniform(Map("low" -> 10, "high" -> 25, "shape" -> dShape2.toString()))()),
+      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+    mod.forward(dataBatch)
+    assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+    mod.backward()
+    mod.update()
+
+    // Train with both different batch size and data shapes
+    dShape1 = Shape(20, 3, 120, 120)
+    dShape2 = Shape(20, 3, 32, 64)
+    lShape = Shape(20)
+    dataBatch = new DataBatch(
+      data = IndexedSeq(
+        NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+    mod.forward(dataBatch)
+    assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+    mod.backward()
+    mod.update()
+
+    dShape1 = Shape(5, 3, 28, 40)
+    dShape2 = Shape(5, 3, 24, 16)
+    lShape = Shape(5)
+    dataBatch = new DataBatch(
+      data = IndexedSeq(
+        NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+        NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))()),
+      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+    mod.forward(dataBatch)
+    assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+    mod.backward()
+    mod.update()
+  }
+}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
index 187869c..ac1cee2 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
@@ -239,7 +239,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll
       var exe = x.simpleBind(ctx = Context.cpu(), gradReq = "write", shapeDict = Map())
       exe.forward(isTrain = false)
       assert(exe.gradArrays.length == 0)
-      assert(CheckUtils.reldiff(result.toArray, exe.outputs.head.toArray) <= 1e-5f)
+      assert(CheckUtils.reldiff(result.toArray, exe.outputs.head.toArray) <= 1e-4f)
     }
   }
 

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.