You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/10/23 23:49:53 UTC
[incubator-mxnet] branch master updated: use ResourceScope in
Model/Trainer/FeedForward.scala (#12882)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6b4df85 use ResourceScope in Model/Trainer/FeedForward.scala (#12882)
6b4df85 is described below
commit 6b4df8576e373ff68b4fcd99ae6318ddb4b9ed12
Author: Naveen Swamy <mn...@gmail.com>
AuthorDate: Tue Oct 23 16:49:37 2018 -0700
use ResourceScope in Model/Trainer/FeedForward.scala (#12882)
* use ResourceScope in Model/Trainer/FeedForward.scala
* add a public moveToOuterScope method that moves resources to an outer scope, if one exists
* fix a memory leak in FeedForward.scala by making it a native resource and disposing argParams and auxParams
in its dispose() method
---
.../main/scala/org/apache/mxnet/FeedForward.scala | 152 +++++++++++++--------
.../scala/org/apache/mxnet/NativeResource.scala | 8 +-
.../scala/org/apache/mxnet/ResourceScope.scala | 35 +++--
.../imclassification/TrainModel.scala | 80 +++++------
.../imclassification/util/Trainer.scala | 133 +++++++++---------
5 files changed, 230 insertions(+), 178 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 00a1450..2ed9d8c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -17,9 +17,10 @@
package org.apache.mxnet
+import org.apache.mxnet.Base.CPtrAddress
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.optimizer.SGD
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.ListBuffer
@@ -55,7 +56,7 @@ class FeedForward private(
argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
private val allowExtraParams: Boolean,
- val beginEpoch: Int) {
+ val beginEpoch: Int) extends NativeResource {
val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
private var argumentChecked = false
@@ -126,6 +127,8 @@ class FeedForward private(
}
// Initialize weight parameters and auxiliary states
+ // The NDArrays associated with _argParams and _auxParams are not disposed here; instead
+ // they are moved to an outer scope, if one is available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
@@ -137,16 +140,26 @@ class FeedForward private(
val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
paramNames.contains(name)
}
- val argParams = paramNameShapes.map { case (name, shape) =>
- (name, NDArray.zeros(shape))
+ val argParams = paramNameShapes.map { case (name, shape) => {
+ val param = NDArray.zeros(shape)
+ val curScope = ResourceScope.getCurrentScope()
+ if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+ (name, param)
+ }
}.toMap
- val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
- (name, NDArray.zeros(shape))
+
+ val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
+ val param = NDArray.zeros(shape)
+ val curScope = ResourceScope.getCurrentScope()
+ if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+ (name, param)
+ }
}.toMap
for ((k, v) <- argParams) {
if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
argParams(k).set(_argParams(k))
+
} else {
initializer(k, v)
}
@@ -277,13 +290,15 @@ class FeedForward private(
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric, kvStoreType: String,
epochEndCallback: EpochEndCallback, batchEndCallback: BatchEndCallback,
logger: Logger, workLoadList: Seq[Float]): Unit = {
- // init params first to allow kv store use _argParams to decide its type
- initSymbolParams(trainData)
- // create kvstore
- val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
- fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
- epochEndCallback, batchEndCallback, logger, workLoadList)
- kvStore.foreach(_.dispose())
+ ResourceScope.using() {
+ // init params first to allow kv store use _argParams to decide its type
+ initSymbolParams(trainData)
+ // create kvstore
+ val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType, ctx.length, _argParams)
+ fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+ epochEndCallback, batchEndCallback, logger, workLoadList)
+// kvStore.foreach(_.dispose())
+ }
}
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -313,11 +328,13 @@ class FeedForward private(
batchEndCallback: BatchEndCallback, logger: Logger,
workLoadList: Seq[Float]): Unit = {
// init params first to allow kv store use _argParams to decide its type
- initSymbolParams(trainData)
- // create kvstore
- val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
- fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
- epochEndCallback, batchEndCallback, logger, workLoadList)
+ ResourceScope.using() {
+ initSymbolParams(trainData)
+ // create kvstore
+ val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
+ fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+ epochEndCallback, batchEndCallback, logger, workLoadList)
+ }
}
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -352,44 +369,49 @@ class FeedForward private(
batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
workLoadList: Seq[Float] = null): Unit = {
require(evalMetric != null, "evalMetric cannot be null")
- val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
-
- // init optimizer
- val batchSizeMultiplier = kvStore.map { kv =>
- if (kv.`type` == "dist_sync") {
- kv.numWorkers
- } else {
- 1
- }
- }
- val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
- this.optimizer.setArgNames(argNames)
- this.optimizer.setRescaleGrad(1f / batchSize)
- this.optimizer.setSymbol(this.symbol)
- val paramIdx2Name =
- if (updateOnKVStore) {
- paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
- } else {
- paramNames.zipWithIndex.flatMap { case (name, idx) =>
- (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
- }.toMap
+ // TODO: https://issues.apache.org/jira/browse/MXNET-1171
+ // this leaks memory, initSymbolParams->initParams is already called which allocates
+ // NDArray in argParams, auxParams and here we are overwriting it by calling again.
+ // PhantomRef should take care of releasing this when GC is called, however we have to
+ // wait for the GC call to happen.
+ val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
+
+ // init optimizer
+ val batchSizeMultiplier = kvStore.map { kv =>
+ if (kv.`type` == "dist_sync") {
+ kv.numWorkers
+ } else {
+ 1
+ }
}
- this.optimizer.setIdx2Name(paramIdx2Name)
-
- logger.debug("Start training on multi-device")
- Model.trainMultiDevice(
- symbol, ctx, argNames, paramNames, auxNames,
- _argParams, _auxParams,
- this.beginEpoch, this.numEpoch,
- this.epochSize, this.optimizer,
- kvStore, updateOnKVStore,
- trainData = trainData, evalData = Option(evalData),
- evalMetric = evalMetric,
- epochEndCallback = Option(epochEndCallback),
- batchEndCallback = Option(batchEndCallback),
- workLoadList = workLoadList,
- monitor = monitor,
- symGen = symGen)
+ val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
+ this.optimizer.setArgNames(argNames)
+ this.optimizer.setRescaleGrad(1f / batchSize)
+ this.optimizer.setSymbol(this.symbol)
+ val paramIdx2Name =
+ if (updateOnKVStore) {
+ paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
+ } else {
+ paramNames.zipWithIndex.flatMap { case (name, idx) =>
+ (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
+ }.toMap
+ }
+ this.optimizer.setIdx2Name(paramIdx2Name)
+
+ logger.debug("Start training on multi-device")
+ Model.trainMultiDevice(
+ symbol, ctx, argNames, paramNames, auxNames,
+ _argParams, _auxParams,
+ this.beginEpoch, this.numEpoch,
+ this.epochSize, this.optimizer,
+ kvStore, updateOnKVStore,
+ trainData = trainData, evalData = Option(evalData),
+ evalMetric = evalMetric,
+ epochEndCallback = Option(epochEndCallback),
+ batchEndCallback = Option(batchEndCallback),
+ workLoadList = workLoadList,
+ monitor = monitor,
+ symGen = symGen)
}
/**
@@ -416,9 +438,29 @@ class FeedForward private(
def serialize(): Array[Byte] = {
Model.serialize(this.symbol, getArgParams, getAuxParams)
}
+
+ // Hack to make FeedForward work with ResourceScope and
+ // automatically release _argParams and _auxParams
+ override def nativeAddress: CPtrAddress = hashCode()
+
+ override def nativeDeAllocator: CPtrAddress => Int = FeedForward.doNothingDeAllocator
+
+ override val ref: NativeResourceRef = super.register()
+
+ override val bytesAllocated: Long = 0L
+
+ override def dispose(): Unit = {
+ if (!super.isDisposed) {
+ _argParams.foreach { case (_, param) => param.dispose() }
+ _auxParams.foreach { case (_, param) => param.dispose() }
+ }
+ }
}
object FeedForward {
+
+ private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+
private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
// Check if name is a data argument.
private def isDataArg(name: String): Boolean = {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
index 48d4b0c..1806b86 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -46,7 +46,8 @@ private[mxnet] trait NativeResource
*/
def nativeDeAllocator: (CPtrAddress => Int)
- /** Call NativeResource.register to get the reference
+ /**
+ * Call NativeResource.register to get the reference
*/
val ref: NativeResourceRef
@@ -56,6 +57,7 @@ private[mxnet] trait NativeResource
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long
+ // this is set and unset by [[ResourceScope.add]] and [[ResourceScope.remove]]
private[mxnet] var scope: Option[ResourceScope] = None
@volatile private var disposed = false
@@ -69,11 +71,11 @@ private[mxnet] trait NativeResource
* using PhantomReference
*/
def register(): NativeResourceRef = {
- scope = ResourceScope.getCurrentScope()
+ val scope = ResourceScope.getCurrentScope()
if (scope.isDefined) scope.get.add(this)
NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
- // register with PhantomRef tracking to release incase the objects go
+ // register with PhantomRef tracking to release in case the objects go
// out of reference within scope but are held for long time
NativeResourceRef.register(this, nativeDeAllocator)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index 1c5782d..30fe147 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable {
*/
def add(resource: NativeResource): Unit = {
resourceQ.+=(resource)
+ resource.scope = Some(this)
}
/**
@@ -67,7 +68,21 @@ class ResourceScope extends AutoCloseable {
*/
def remove(resource: NativeResource): Unit = {
resourceQ.-=(resource)
+ resource.scope = None
}
+
+ /**
+ * Removes from current Scope and moves to outer scope if it exists
+ * @param resource Resource to be moved to an outer scope
+ */
+ def moveToOuterScope(resource: NativeResource): Unit = {
+ val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+ if (prevScope.isDefined) {
+ this.remove(resource)
+ prevScope.get.add(resource)
+ } else this.remove(resource)
+ }
+
}
object ResourceScope {
@@ -92,32 +107,22 @@ object ResourceScope {
val curScope = if (scope != null) scope else new ResourceScope()
- val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
-
@inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
g.foreach( n =>
n match {
case nRes: NativeResource => {
- removeAndAddToPrevScope(nRes)
+ curScope.moveToOuterScope(nRes)
}
case kv: scala.Tuple2[_, _] => {
- if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._1.asInstanceOf[NativeResource])
- if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._2.asInstanceOf[NativeResource])
}
}
)
}
- @inline def removeAndAddToPrevScope(r: NativeResource) = {
- curScope.remove(r)
- if (prevScope.isDefined) {
- prevScope.get.add(r)
- r.scope = prevScope
- }
- }
-
@inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
}
@@ -129,8 +134,8 @@ object ResourceScope {
ret match {
// don't de-allocate if returning any collection that contains NativeResource.
case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
- case nRes: NativeResource => removeAndAddToPrevScope(nRes)
- case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) )
+ case nRes: NativeResource => curScope.moveToOuterScope(nRes)
+ case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
case _ => // do nothing
}
ret
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 608e191..f6c283c 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -43,7 +43,7 @@ object TrainModel {
*/
def test(model: String, dataPath: String, numExamples: Int = 60000,
numEpochs: Int = 10, benchmark: Boolean = false): Float = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val devs = Array(Context.cpu(0))
val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
@@ -110,44 +110,46 @@ object TrainModel {
val inst = new TrainModel
val parser: CmdLineParser = new CmdLineParser(inst)
try {
- parser.parseArgument(args.toList.asJava)
-
- val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
- else inst.dataDir
-
- val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
- inst.numLayers, inst.numExamples, inst.benchmark)
-
- val devs =
- if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
- else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
- else Array(Context.cpu(0))
-
- val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
- envs.put("DMLC_ROLE", inst.role)
- if (inst.schedulerHost != null) {
- require(inst.schedulerPort > 0, "scheduler port not specified")
- envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
- envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
- require(inst.numWorker > 0, "Num of workers must > 0")
- envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
- require(inst.numServer > 0, "Num of servers must > 0")
- envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
- logger.info("Init PS environments")
- KVStoreServer.init(envs.toMap)
- }
-
- if (inst.role != "worker") {
- logger.info("Start KVStoreServer for scheduler & servers")
- KVStoreServer.start()
- } else {
- Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
- network = net, dataLoader = dataLoader,
- kvStore = inst.kvStore, numEpochs = inst.numEpochs,
- modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
- lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
- monitorSize = inst.monitor)
- logger.info("Finish fit ...")
+ ResourceScope.using() {
+ parser.parseArgument(args.toList.asJava)
+
+ val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
+ else inst.dataDir
+
+ val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
+ inst.numLayers, inst.numExamples, inst.benchmark)
+
+ val devs =
+ if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
+ else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
+ else Array(Context.cpu(0))
+
+ val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
+ envs.put("DMLC_ROLE", inst.role)
+ if (inst.schedulerHost != null) {
+ require(inst.schedulerPort > 0, "scheduler port not specified")
+ envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
+ envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
+ require(inst.numWorker > 0, "Num of workers must > 0")
+ envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
+ require(inst.numServer > 0, "Num of servers must > 0")
+ envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
+ logger.info("Init PS environments")
+ KVStoreServer.init(envs.toMap)
+ }
+
+ if (inst.role != "worker") {
+ logger.info("Start KVStoreServer for scheduler & servers")
+ KVStoreServer.start()
+ } else {
+ Trainer.fit(batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
+ network = net, dataLoader = dataLoader,
+ kvStore = inst.kvStore, numEpochs = inst.numEpochs,
+ modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
+ lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
+ monitorSize = inst.monitor)
+ logger.info("Finish fit ...")
+ }
}
} catch {
case ex: Exception => {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 9a54e58..276816c 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -50,83 +50,84 @@ object Trainer {
lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
// kvstore
- var kv = KVStore.create(kvStore)
+ ResourceScope.using() {
+ var kv = KVStore.create(kvStore)
- // load model
- val modelPrefixWithRank =
- if (modelPrefix == null) null
- else modelPrefix + s"-${kv.rank}"
+ // load model
+ val modelPrefixWithRank =
+ if (modelPrefix == null) null
+ else modelPrefix + s"-${kv.rank}"
- val (argParams, auxParams, beginEpoch) =
- if (loadEpoch >= 0) {
- require(modelPrefixWithRank != null)
- val tmp = FeedForward.load(modelPrefix, loadEpoch)
- (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
- } else {
- (null, null, 0)
- }
+ val (argParams, auxParams, beginEpoch) =
+ if (loadEpoch >= 0) {
+ require(modelPrefixWithRank != null)
+ val tmp = FeedForward.load(modelPrefix, loadEpoch)
+ (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
+ } else {
+ (null, null, 0)
+ }
- // save model
- val checkpoint: EpochEndCallback =
- if (modelPrefix == null) null
- else new EpochEndCallback {
- override def invoke(epoch: Int, symbol: Symbol,
- argParams: Map[String, NDArray],
- auxStates: Map[String, NDArray]): Unit = {
- Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams)
+ // save model
+ val checkpoint: EpochEndCallback =
+ if (modelPrefix == null) null
+ else new EpochEndCallback {
+ override def invoke(epoch: Int, symbol: Symbol,
+ argParams: Map[String, NDArray],
+ auxStates: Map[String, NDArray]): Unit = {
+ Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxParams)
+ }
}
- }
- // data
- val (train, validation) = dataLoader(batchSize, kv)
+ // data
+ val (train, validation) = dataLoader(batchSize, kv)
- // train
- val epochSize =
- if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
- else numExamples / batchSize
+ // train
+ val epochSize =
+ if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
+ else numExamples / batchSize
- val lrScheduler =
- if (lrFactor < 1f) {
- new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1),
- factor = lrFactor)
- } else {
- null
- }
- val optimizer: Optimizer = new SGD(learningRate = lr,
- lrScheduler = lrScheduler, clipGradient = clipGradient,
- momentum = 0.9f, wd = 0.00001f)
+ val lrScheduler =
+ if (lrFactor < 1f) {
+ new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1),
+ factor = lrFactor)
+ } else {
+ null
+ }
+ val optimizer: Optimizer = new SGD(learningRate = lr,
+ lrScheduler = lrScheduler, clipGradient = clipGradient,
+ momentum = 0.9f, wd = 0.00001f)
- // disable kvstore for single device
- if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
- kv.dispose()
- kv = null
- }
+ // disable kvstore for single device
+ if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
+ kv.dispose()
+ kv = null
+ }
- val model = new FeedForward(ctx = devs,
- symbol = network,
- numEpoch = numEpochs,
- optimizer = optimizer,
- initializer = new Xavier(factorType = "in", magnitude = 2.34f),
- argParams = argParams,
- auxParams = auxParams,
- beginEpoch = beginEpoch,
- epochSize = epochSize)
- if (monitorSize > 0) {
- model.setMonitor(new Monitor(monitorSize))
- }
- val acc = new Accuracy()
- model.fit(trainData = train,
- evalData = validation,
- evalMetric = acc,
- kvStore = kv,
- batchEndCallback = new Speedometer(batchSize, 50),
- epochEndCallback = checkpoint)
- if (kv != null) {
- kv.dispose()
+ val model = new FeedForward(ctx = devs,
+ symbol = network,
+ numEpoch = numEpochs,
+ optimizer = optimizer,
+ initializer = new Xavier(factorType = "in", magnitude = 2.34f),
+ argParams = argParams,
+ auxParams = auxParams,
+ beginEpoch = beginEpoch,
+ epochSize = epochSize)
+ if (monitorSize > 0) {
+ model.setMonitor(new Monitor(monitorSize))
+ }
+ val acc = new Accuracy()
+ model.fit(trainData = train,
+ evalData = validation,
+ evalMetric = acc,
+ kvStore = kv,
+ batchEndCallback = new Speedometer(batchSize, 50),
+ epochEndCallback = checkpoint)
+ if (kv != null) {
+ kv.dispose()
+ }
+ acc
}
- acc
}
-
// scalastyle:on parameterNum
}