Posted to commits@mxnet.apache.org by li...@apache.org on 2017/08/04 15:47:00 UTC
[incubator-mxnet] branch master updated: [Scala] Make Module Api
sync with Python interface (#7246)
This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 14b83fc [Scala] Make Module Api sync with Python interface (#7246)
14b83fc is described below
commit 14b83fccef7b96f8d38d780dbce3d0ef47267934
Author: 梁德澎 <li...@gmail.com>
AuthorDate: Fri Aug 4 23:46:51 2017 +0800
[Scala] Make Module Api sync with Python interface (#7246)
* [Scala] Make Module Api sync with Python interface
* fix
---
.../scala/ml/dmlc/mxnet/module/BaseModule.scala | 38 ++-
.../mxnet/module/DataParallelExecutorGroup.scala | 53 ++-
.../main/scala/ml/dmlc/mxnet/module/Module.scala | 69 +++-
.../ml/dmlc/mxnet/module/SequentialModule.scala | 10 +-
.../src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala | 368 +++++++++++++++++++++
.../test/scala/ml/dmlc/mxnet/OperatorSuite.scala | 2 +-
6 files changed, 514 insertions(+), 26 deletions(-)
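For context, these changes bring the Scala Module API in line with Python's
module interface: a new `reshape` call lets a bound module accept new input
shapes without re-binding, and a new `allowExtra` flag lets parameter loading
ignore entries the symbol does not use. A minimal sketch of the new surface
(the symbol, shapes, and value names below are illustrative assumptions, not
taken from this commit):

    import ml.dmlc.mxnet._
    import ml.dmlc.mxnet.module.Module

    // build a trivial network and bind it for batch size 7
    var sym = Symbol.Variable("data")
    sym = Symbol.FullyConnected("fc")()(Map("data" -> sym, "num_hidden" -> 20))
    val mod = new Module(sym, IndexedSeq("data"), null)
    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(7, 20))))
    mod.initParams()

    // new in this commit: reshape the bound executors for batch size 14
    mod.reshape(IndexedSeq(DataDesc("data", Shape(14, 20))))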
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
index c1cb91d..0a73e1a 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
@@ -121,6 +121,7 @@ abstract class BaseModule {
private[module] var auxParams: Map[String, NDArray] = null
// High Level API
+ def getSymbol: Symbol = this.symbol
// A convenient function that calls both `forward` and `backward`.
def forwardBackward(dataBatch: DataBatch): Unit = {
@@ -259,7 +260,7 @@ abstract class BaseModule {
/**
* Get parameters; these are potentially copies of the actual parameters used
* to do computation on the device.
- * @return `(arg_params, aux_params)`, a pair of dictionary of name to value mapping.
+ * @return `(argParams, auxParams)`, a pair of dictionary of name to value mapping.
*/
def getParams: (Map[String, NDArray], Map[String, NDArray])
@@ -267,41 +268,52 @@ abstract class BaseModule {
* Initialize the parameters and auxiliary states.
* @param initializer : Initializer
* Called to initialize parameters if needed.
- * arg_params : dict
+ * argParams : dict
* If not None, should be a dictionary of existing arg_params. Initialization
* will be copied from that.
- * aux_params : dict
+ * auxParams : dict
* If not None, should be a dictionary of existing aux_params. Initialization
* will be copied from that.
- * allow_missing : bool
+ * allowMissing : bool
* If true, params could contain missing values, and the initializer will be
* called to fill those missing params.
- * force_init : bool
+ * forceInit : bool
* If true, will force re-initialize even if already initialized.
+ * allowExtra : bool
+ * Whether to allow extra parameters that are not needed by the symbol.
+ * If this is true, no error will be thrown when argParams or auxParams
+ * contain extra parameters that are not needed by the executor.
*/
def initParams(initializer: Initializer = new Uniform(0.01f),
argParams: Map[String, NDArray] = null,
auxParams: Map[String, NDArray] = null,
- allowMissing: Boolean = false, forceInit: Boolean = false): Unit
+ allowMissing: Boolean = false,
+ forceInit: Boolean = false,
+ allowExtra: Boolean = false): Unit
/**
* Assign parameter and aux state values.
- * arg_params : dict
+ * argParams : dict
* Dictionary of name to value (`NDArray`) mapping.
- * aux_params : dict
+ * auxParams : dict
* Dictionary of name to value (`NDArray`) mapping.
- * allow_missing : bool
+ * allowMissing : bool
* If true, params could contain missing values, and the initializer will be
* called to fill those missing params.
- * force_init : bool
+ * forceInit : bool
* If true, will force re-initialize even if already initialized.
+ * allowExtra : bool
+ * Whether to allow extra parameters that are not needed by the symbol.
+ * If this is true, no error will be thrown when argParams or auxParams
+ * contain extra parameters that are not needed by the executor.
*/
def setParams(argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
allowMissing: Boolean = false,
- forceInit: Boolean = true): Unit = {
- initParams(initializer = null, argParams = argParams, auxParams = auxParams,
- allowMissing = allowMissing, forceInit = forceInit)
+ forceInit: Boolean = true,
+ allowExtra: Boolean = false): Unit = {
+ initParams(initializer = null, argParams, auxParams,
+ allowMissing, forceInit, allowExtra)
}
/**
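With this change `setParams` simply forwards to `initParams`, so both honor
the same `allowExtra` semantics. Continuing the sketch above, a hedged example
of loading a parameter map with one unused and one missing entry (the values
are made up for illustration):

    // "unused_w" is not part of the symbol and is ignored because
    // allowExtra = true; the absent "fc_bias" is tolerated because
    // allowMissing = true
    val loaded = Map(
      "fc_weight" -> NDArray.ones(20, 20),
      "unused_w" -> NDArray.ones(2, 2))
    mod.setParams(argParams = loaded, auxParams = null,
      allowMissing = true, forceInit = true, allowExtra = true)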
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
index 2e724c6..ea78962 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/DataParallelExecutorGroup.scala
@@ -297,6 +297,7 @@ class DataParallelExecutorGroup private[module](
private var batchSize: Int = -1
private var slices: Array[(Int, Int)] = null
+ private var _defaultExecs: Array[Executor] = null
private var execs: Array[Executor] = null
private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -305,8 +306,8 @@ class DataParallelExecutorGroup private[module](
private[module] var auxArrays: IndexedSeq[Array[NDArray]] = null
private var inputGradArrays: IndexedSeq[Array[NDArray]] = null
- private val dataLayouts = decideSlices(dataShapes)
- private val labelLayouts =
+ private var dataLayouts = decideSlices(dataShapes)
+ private var labelLayouts =
// call it to make sure labels have the same batch size as data
if (labelShapes != None) decideSlices(labelShapes.get)
else null
@@ -349,12 +350,30 @@ class DataParallelExecutorGroup private[module](
* @param dataShapes DataDesc for input data.
* @param labelShapes DataDesc for input labels.
* @param sharedGroup
+ * @param reshape
*/
def bindExec(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]],
- sharedGroup: Option[DataParallelExecutorGroup]): Unit = {
- execs = (0 until contexts.length).map(i =>
- bindIthExec(i, dataShapes, labelShapes, sharedGroup)
- ).toArray
+ sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
+ this.batchSize = -1
+ dataLayouts = decideSlices(dataShapes)
+ labelLayouts = {
+ // call it to make sure labels have the same batch size as data
+ if (labelShapes != None) decideSlices(labelShapes.get)
+ else null
+ }
+ if (reshape) {
+ (0 until contexts.length).foreach { i =>
+ val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
+ val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
+ val inputShapes
+ = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
+ execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
+ }
+ } else {
+ execs = (0 until contexts.length).map(i =>
+ bindIthExec(i, dataShapes, labelShapes, sharedGroup)
+ ).toArray
+ }
// convenient data structures
dataArrays = dataShapes.map(dataDesc =>
@@ -400,12 +419,30 @@ class DataParallelExecutorGroup private[module](
}
/**
+ * Reshape executors.
+ * @param dataShapes
+ * @param labelShapes
+ */
+ def reshape(dataShapes: Seq[DataDesc], labelShapes: Option[Seq[DataDesc]]): Unit = {
+ if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
+ if (this._defaultExecs == null) {
+ this._defaultExecs = this.execs.map(x => x)
+ }
+ this.bindExec(dataShapes, labelShapes, None, reshape = true)
+ }
+ }
+
+ /**
* Assign, i.e. copy parameters to all the executors.
* @param argParams A dictionary of name to `NDArray` parameter mapping.
* @param auxParams A dictionary of name to `NDArray` auxiliary variable mapping.
+ * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+ * If this is true, no error will be thrown when argParams or auxParams
+ * contain extra parameters that are not needed by the executor.
*/
- def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
- execs.foreach(_.copyParamsFrom(argParams, auxParams))
+ def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray],
+ allowExtra: Boolean = false): Unit = {
+ execs.foreach(_.copyParamsFrom(argParams, auxParams, allowExtraParams = allowExtra))
}
/**
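Internally, the group now caches its first executors in `_defaultExecs` and,
on reshape, calls `Executor.reshape` on them instead of re-binding from
scratch. The underlying per-executor call looks like this (the `exec` value
and target shape are assumptions for illustration):

    // resize the bound buffers in place; allowUpSizing = true permits
    // allocating larger arrays when the new batch is bigger
    val reshaped = exec.reshape(allowUpSizing = true,
      kwargs = Map("data" -> Shape(14, 20)))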
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
index 2b1d743..b9cc078 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
@@ -107,11 +107,16 @@ class Module(symbolVar: Symbol,
* @param allowMissing If true, params could contain missing values,
* and the initializer will be called to fill those missing params.
* @param forceInit If true, will force re-initialize even if already initialized.
+ * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+ * If this is true, no error will be thrown when argParams or auxParams
+ * contain extra parameters that are not needed by the executor.
*/
override def initParams(initializer: Initializer = new Uniform(0.01f),
argParams: Map[String, NDArray] = null,
auxParams: Map[String, NDArray] = null,
- allowMissing: Boolean = false, forceInit: Boolean = false): Unit = {
+ allowMissing: Boolean = false,
+ forceInit: Boolean = false,
+ allowExtra: Boolean = false): Unit = {
if (paramsInitialized && !forceInit) {
return
}
@@ -141,7 +146,7 @@ class Module(symbolVar: Symbol,
this.paramsDirty = false
// copy the initialized parameters to devices
- this.execGroup.setParams(this.argParams, this.auxParams)
+ this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
}
// Internal helper for parameter initialization
@@ -262,6 +267,46 @@ class Module(symbolVar: Symbol,
}
/**
+ * Check that input names match input data descriptors.
+ */
+ @throws(classOf[IllegalArgumentException])
+ private def _checkNamesMatch(dataNames: IndexedSeq[String], dataShapes: IndexedSeq[DataDesc],
+ name: String, throwEx: Boolean): Unit = {
+ val actual = dataShapes.map(_.name)
+ if (dataNames.sorted != actual.sorted) {
+ val msg = s"Data provided by ${name}_shapes don't match names specified by " +
+ s"${name}_names (${dataShapes.mkString(", ")} vs. ${dataNames.mkString(", ")})"
+ if (throwEx) throw new IllegalArgumentException(msg)
+ else logger.warn(msg)
+ }
+ }
+
+ /**
+ * Parse data_attrs into DataDesc format and check that names match.
+ */
+ @throws(classOf[IllegalArgumentException])
+ private def _parseDataDesc(dataNames: IndexedSeq[String], labelNames: IndexedSeq[String],
+ dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]):
+ (IndexedSeq[DataDesc], Option[IndexedSeq[DataDesc]]) = {
+ _checkNamesMatch(dataNames, dataShapes, "data", true)
+ if (labelShapes != None) _checkNamesMatch(labelNames, labelShapes.get, "label", false)
+ (dataShapes, labelShapes)
+ }
+
+ /**
+ * Reshapes the module for new input shapes.
+ * @param dataShapes Typically is `dataIter.provideData`.
+ * @param labelShapes Typically is `dataIter.provideLabel`.
+ */
+ def reshape(dataShapes: IndexedSeq[DataDesc],
+ labelShapes: Option[IndexedSeq[DataDesc]] = None): Unit = {
+ require(this.binded)
+ val (tdataShapes, tlabelShapes) = this._parseDataDesc(
+ this.dataNames, this.labelNames, dataShapes, labelShapes)
+ this.execGroup.reshape(tdataShapes, tlabelShapes)
+ }
+
+ /**
* Install and initialize optimizers.
* @param kvstore
* @param optimizer
@@ -344,6 +389,26 @@ class Module(symbolVar: Symbol,
*/
def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
require(binded && paramsInitialized)
+ val currDataShapes = this.dataShapes.map(_.shape)
+ val newDataShapes = dataBatch.data.map(_.shape)
+ if (currDataShapes != newDataShapes) {
+ val newDShapes: IndexedSeq[DataDesc] =
+ if (dataBatch.provideData != null) dataBatch.provideData
+ else {
+ this.dataShapes.zip(newDataShapes).map { case (i, shape) =>
+ DataDesc(i.name, shape, i.dtype, i.layout)
+ }
+ }
+ val newLShapes: Option[IndexedSeq[DataDesc]] =
+ if (dataBatch.provideLabel != null) Some(dataBatch.provideLabel)
+ else if (dataBatch.label != null && dataBatch.label.length > 0
+ && this.labelShapes != null) {
+ Some(this.labelShapes.zip(dataBatch.label).map { case (i, j) =>
+ DataDesc(i.name, j.shape, i.dtype, i.layout)
+ })
+ } else None
+ this.reshape(newDShapes, newLShapes)
+ }
execGroup.forward(dataBatch, isTrain)
}
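`forward` now compares the incoming batch's shapes against the bound ones and
reshapes on the fly, matching the Python behavior. Continuing the sketch
above, feeding a smaller batch no longer requires an explicit `reshape` call
(shapes are again assumptions):

    // the module was reshaped for batch size 14 above; a batch of 5
    // triggers an implicit reshape before the actual forward pass
    mod.forward(new DataBatch(
      data = IndexedSeq(NDArray.ones(5, 20)),
      label = null, index = null, pad = 0))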
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
index dfa63eb..a77041d 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala
@@ -144,11 +144,16 @@ class SequentialModule extends BaseModule {
* @param allowMissing If true, params could contain missing values,
* and the initializer will be called to fill those missing params.
* @param forceInit If true, will force re-initialize even if already initialized.
+ * @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
+ * If this is true, no error will be thrown when argParams or auxParams
+ * contain extra parameters that are not needed by the executor.
*/
override def initParams(initializer: Initializer = new Uniform(0.01f),
argParams: Map[String, NDArray] = null,
auxParams: Map[String, NDArray] = null,
- allowMissing: Boolean = false, forceInit: Boolean = false): Unit = {
+ allowMissing: Boolean = false,
+ forceInit: Boolean = false,
+ allowExtra: Boolean = false): Unit = {
if (this.paramsInitialized && !forceInit) {
return
}
@@ -156,7 +161,8 @@ class SequentialModule extends BaseModule {
for (module <- this.modules) {
module.initParams(initializer = initializer, argParams = argParams,
- auxParams = auxParams, allowMissing = allowMissing, forceInit = forceInit)
+ auxParams = auxParams, allowMissing = allowMissing,
+ forceInit = forceInit, allowExtra = allowExtra)
}
// Internal function to help checking duplicated names,
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
new file mode 100644
index 0000000..ab48ef7
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ModuleSuite.scala
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import ml.dmlc.mxnet.CheckUtils._
+import ml.dmlc.mxnet.module._
+import ml.dmlc.mxnet.optimizer._
+import ml.dmlc.mxnet.io._
+
+class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+ test ("model dtype") {
+ val dType = DType.Float16
+ val dShape = Shape(3, 8, 7)
+
+ var sym = Symbol.Variable("data")
+ sym = Symbol.Activation(attr = Map("__layout__" -> "TNC"))()(
+ Map("data" -> sym, "act_type" -> "relu"))
+
+ val mod = new Module(sym, IndexedSeq("data"), null,
+ contexts = Array(Context.cpu(0), Context.cpu(1)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, "TNC")))
+ mod.initParams()
+ mod.forward(new DataBatch(
+ data = IndexedSeq(NDArray.ones(dShape, dtype = dType)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(dShape, dtype = dType)))
+
+ assert(mod.getOutputs.flatten.forall(_.dtype == dType))
+ }
+
+ test ("module input_grads") {
+ val a = Symbol.Variable("a", kwargs = Map("__layout__" -> "NC"))
+ val b = Symbol.Variable("b", kwargs = Map("__layout__" -> "NC"))
+ var c = Symbol.Variable("c", kwargs = Map("__layout__" -> "NC"))
+
+ import SymbolConversions._
+ c = a + 2 * b + 3 * c
+
+ val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
+ contexts = Array(Context.cpu(0), Context.cpu(1)))
+ mod.bind(dataShapes = IndexedSeq(
+ DataDesc("b", Shape(5, 5)),
+ DataDesc("c", Shape(5, 5)),
+ DataDesc("a", Shape(5, 5))),
+ inputsNeedGrad = true
+ )
+ mod.initParams()
+ mod.forward(new DataBatch(
+ data = IndexedSeq(
+ NDArray.ones(5, 5), NDArray.ones(5, 5), NDArray.ones(5, 5)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(5, 5)))
+
+ val inputGrads = mod.getInputGradsMerged()
+ val aGrad = inputGrads(0).toArray
+ val bGrad = inputGrads(1).toArray
+ val cGrad = inputGrads(2).toArray
+
+ assert(aGrad.forall(_ == 1f))
+ assert(bGrad.forall(_ == 2f))
+ assert(cGrad.forall(_ == 3f))
+ }
+
+ test ("module layout") {
+ var sym = Symbol.Variable("data")
+ sym = Symbol.Activation(attr = Map("__layout__" -> "TNC"))()(
+ Map("data" -> sym, "act_type" -> "relu"))
+
+ val dShape = Shape(3, 8, 7)
+ val mod = new Module(sym, IndexedSeq("data"), null,
+ contexts = Array(Context.cpu(0), Context.cpu(1)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "TNC")))
+ mod.initParams()
+ mod.forward(new DataBatch(
+ data = IndexedSeq(NDArray.ones(dShape)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(dShape)))
+ assert(mod.getOutputsMerged()(0).shape == dShape)
+
+ val hdShape = Shape(3, 4, 7)
+ for (x <- mod.getOutputs) assert(x(0).shape == hdShape)
+ }
+
+ test ("save load") {
+ def mapEqu(a: Map[String, NDArray], b: Map[String, NDArray]): Unit = {
+ assert(a.toSet == b.toSet)
+ for (k <- a.keys) assert(a(k) == b(k))
+ }
+
+ var sym = Symbol.Variable("data")
+ sym = Symbol.FullyConnected()()(Map("data" -> sym, "num_hidden" -> 100))
+
+ // single device
+ var mod = new Module(sym, IndexedSeq("data"), null)
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.initParams()
+ mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+ mod.update()
+ mod.saveCheckpoint("test", 0, saveOptStates = true)
+
+ var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+ assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
+ mapEqu(mod.getParams._1, mod2.getParams._1)
+
+ // multi device
+ mod = new Module(sym, IndexedSeq("data"), null,
+ contexts = Array(Context.cpu(0), Context.cpu(1)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod.initParams()
+ mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+ mod.update()
+ mod.saveCheckpoint("test", 0, saveOptStates = true)
+
+ mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
+ mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10))))
+ mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
+ assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
+ mapEqu(mod.getParams._1, mod2.getParams._1)
+ }
+
+ test ("module reshape") {
+ var sym = Symbol.Variable("data")
+ sym = Symbol.FullyConnected("fc")()(Map("data" -> sym, "num_hidden" -> 20))
+
+ var dShape = Shape(7, 20)
+ val mod = new Module(sym, IndexedSeq("data"), null,
+ contexts = Array(Context.cpu(0), Context.cpu(1)))
+ mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape)))
+ mod.initParams()
+ mod.initOptimizer(optimizer = new SGD(learningRate = 1f))
+
+ mod.forward(new DataBatch(
+ data = IndexedSeq(NDArray.ones(dShape)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(dShape)))
+ mod.update()
+ assert(mod.getOutputsMerged()(0).shape == dShape)
+ assert(mod.getParams._1("fc_bias").toArray.forall(_ == -1f))
+
+ dShape = Shape(14, 20)
+ mod.reshape(IndexedSeq(DataDesc("data", dShape)))
+ mod.forward(new DataBatch(
+ data = IndexedSeq(NDArray.ones(dShape)),
+ label = null, index = null, pad = 0))
+ mod.backward(Array(NDArray.ones(dShape)))
+ mod.update()
+ assert(mod.getOutputsMerged()(0).shape == dShape)
+ assert(mod.getParams._1("fc_bias").toArray.forall(x => (x - -3f) < 1e-3))
+ }
+
+ test ("module setParams") {
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val trainData = new NDArrayIter(
+ IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+
+ // symbols
+ var x = Symbol.Variable("data")
+ x = Symbol.FullyConnected(name = "fc_0")()(Map("data" -> x, "num_hidden" -> 2))
+ x = Symbol.Activation(name = "act_0")()(Map("data" -> x, "act_type" -> "sigmoid"))
+ x = Symbol.FullyConnected(name = "fc_1")()(Map("data" -> x, "num_hidden" -> 2))
+ x = Symbol.Activation(name = "act_1")()(Map("data" -> x, "act_type" -> "sigmoid"))
+ x = Symbol.LinearRegressionOutput(name = "softmax")()(Map("data" -> x, "grad_scale" -> 2))
+
+ // create module
+ val mod = new Module(x, contexts = Array(Context.cpu()))
+ mod.bind(dataShapes = trainData.provideData,
+ Option(trainData.provideLabel))
+ val argParamsCorrect = Map(
+ "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+ "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+ "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+ "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+ )
+ val argParamsMissing = Map(
+ "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+ "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+ "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2))
+ )
+ val argParamsExtra = Map(
+ "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+ "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+ "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+ "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2)),
+ "fc_2_weight" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+ )
+
+ mod.setParams(forceInit = true, argParams = argParamsCorrect,
+ auxParams = null)
+
+ // test allow missing
+ mod.setParams(forceInit = true, argParams = argParamsMissing,
+ auxParams = null, allowMissing = true)
+
+ // test allow extra
+ mod.setParams(forceInit = true, argParams = argParamsExtra, auxParams = null,
+ allowMissing = true, allowExtra = true)
+ }
+
+ test ("monitor") {
+ // data iter
+ val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 2))
+ val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 2))
+ val trainData = new NDArrayIter(
+ IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+
+ // symbols
+ var x = Symbol.Variable("data")
+ x = Symbol.FullyConnected(name = "fc_0")()(Map("data" -> x, "num_hidden" -> 2))
+ x = Symbol.Activation(name = "act_0")()(Map("data" -> x, "act_type" -> "sigmoid"))
+ x = Symbol.FullyConnected(name = "fc_1")()(Map("data" -> x, "num_hidden" -> 2))
+ x = Symbol.Activation(name = "act_1")()(Map("data" -> x, "act_type" -> "sigmoid"))
+ x = Symbol.LinearRegressionOutput(name = "softmax")()(Map("data" -> x, "grad_scale" -> 2))
+
+ // create monitor
+ def meanAbs(x: NDArray): NDArray = {
+ val sumAbs = NDArray.sum(NDArray.abs(x))
+ sumAbs / x.shape.product
+ }
+ val mon = new Monitor(1, statFunc = meanAbs)
+
+ // create module
+ val mod = new Module(x, contexts = Array(Context.cpu()))
+ mod.bind(dataShapes = trainData.provideData,
+ Option(trainData.provideLabel))
+ mod.installMonitor(mon)
+ val argParams = Map(
+ "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
+ "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
+ "fc_1_weight" -> NDArray.array(Array(0.4f, 0.45f, 0.5f, 0.55f), Shape(2, 2)),
+ "fc_1_bias" -> NDArray.array(Array(0.6f, 0.6f), Shape(2))
+ )
+ mod.initParams(argParams = argParams)
+
+ val dataBatch = trainData.next()
+ mon.tic()
+ mod.forwardBackward(dataBatch)
+ val res = mon.toc()
+ val keys = Array("act_0", "act_1", "data", "fc_0", "fc_1", "softmax")
+ val monResultCounts = Array(0, 0, 0, 0, 0, 0)
+ assert(res.length == 21)
+ for ((n, k, v) <- res) {
+ var break = false
+ for ((key, idx) <- keys.zipWithIndex) {
+ if (!break && k.startsWith(key)) {
+ monResultCounts(idx) += 1
+ break = true
+ }
+ }
+ }
+ assert(monResultCounts.zip(Array(2, 2, 1, 6, 6, 4)).forall(x => x._1 == x._2))
+ }
+
+ test ("forward reshape") {
+ val numClass = 10
+ val data1 = Symbol.Variable("data1")
+ val data2 = Symbol.Variable("data2")
+ val conv1 = Symbol.Convolution()()(Map("data" -> data1,
+ "kernel" -> "(2, 2)", "num_filter" -> 2, "stride" -> "(2, 2)"))
+ val conv2 = Symbol.Convolution()()(Map("data" -> data2,
+ "kernel" -> "(3, 3)", "num_filter" -> 3, "stride" -> "(1, 1)"))
+ val pooling1 = Symbol.Pooling()()(Map("data" -> conv1,
+ "kernel" -> "(2, 2)", "pool_type" -> "avg", "stride" -> "(1, 1)"))
+ val pooling2 = Symbol.Pooling()()(Map("data" -> conv2,
+ "kernel" -> "(2, 2)", "pool_type" -> "max", "stride" -> "(1, 1)"))
+ val flatten1 = Symbol.flatten()()(Map("data" -> pooling1))
+ val flatten2 = Symbol.flatten()()(Map("data" -> pooling2))
+ val sum = Symbol.sum()()(Map("data" -> flatten1, "axis" -> 1)) +
+ Symbol.sum()()(Map("data" -> flatten2, "axis" -> 1))
+ val fc = Symbol.FullyConnected()()(
+ Map("data" -> sum, "num_hidden" -> numClass))
+ val sym = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc))
+
+ var dShape1 = Shape(10, 3, 64, 64)
+ var dShape2 = Shape(10, 3, 32, 32)
+ var lShape = Shape(10)
+
+ val mod = new Module(sym, IndexedSeq("data1", "data2"))
+ mod.bind(dataShapes = IndexedSeq(
+ DataDesc("data1", dShape1), DataDesc("data2", dShape2)),
+ labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape)))
+ )
+ mod.initParams()
+ mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
+
+ // Train with original data shapes
+ var dataBatch = new DataBatch(
+ data = IndexedSeq(
+ NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+ NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+ label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+ mod.forward(dataBatch)
+ assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+ mod.backward()
+ mod.update()
+
+ dShape1 = Shape(3, 3, 64, 64)
+ dShape2 = Shape(3, 3, 32, 32)
+ lShape = Shape(3)
+ dataBatch = new DataBatch(
+ data = IndexedSeq(
+ NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+ NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+ label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+ mod.forward(dataBatch)
+ assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+ mod.backward()
+ mod.update()
+
+ dShape1 = Shape(20, 3, 64, 64)
+ dShape2 = Shape(20, 3, 32, 32)
+ lShape = Shape(20)
+ dataBatch = new DataBatch(
+ data = IndexedSeq(
+ NDArray.random_uniform(Map("low" -> 3, "high" -> 5, "shape" -> dShape1.toString()))(),
+ NDArray.random_uniform(Map("low" -> 10, "high" -> 25, "shape" -> dShape2.toString()))()),
+ label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+ mod.forward(dataBatch)
+ assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+ mod.backward()
+ mod.update()
+
+ // Train with both different batch size and data shapes
+ dShape1 = Shape(20, 3, 120, 120)
+ dShape2 = Shape(20, 3, 32, 64)
+ lShape = Shape(20)
+ dataBatch = new DataBatch(
+ data = IndexedSeq(
+ NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+ NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()),
+ label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+ mod.forward(dataBatch)
+ assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+ mod.backward()
+ mod.update()
+
+ dShape1 = Shape(5, 3, 28, 40)
+ dShape2 = Shape(5, 3, 24, 16)
+ lShape = Shape(5)
+ dataBatch = new DataBatch(
+ data = IndexedSeq(
+ NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(),
+ NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))()),
+ label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+ mod.forward(dataBatch)
+ assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
+ mod.backward()
+ mod.update()
+ }
+}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
index 187869c..ac1cee2 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
@@ -239,7 +239,7 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll
var exe = x.simpleBind(ctx = Context.cpu(), gradReq = "write", shapeDict = Map())
exe.forward(isTrain = false)
assert(exe.gradArrays.length == 0)
- assert(CheckUtils.reldiff(result.toArray, exe.outputs.head.toArray) <= 1e-5f)
+ assert(CheckUtils.reldiff(result.toArray, exe.outputs.head.toArray) <= 1e-4f)
}
}
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].